/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.HingeLoss;
import ai.djl.training.loss.L1Loss;
import ai.djl.training.loss.L2Loss;
import ai.djl.training.loss.SigmoidBinaryCrossEntropyLoss;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.metrics.TrainingMetric;
import java.util.stream.IntStream;

/**
 * Base class for loss functions that also act as a {@link TrainingMetric}.
 *
 * <p>Subclasses supply the actual computation through {@link #getLoss(NDList, NDList)}. This
 * class keeps a running sum of the loss values and a running count of loss elements, so that
 * {@link #getValue()} can report the mean loss accumulated since the last {@link #reset()}.
 *
 * <p>It also exposes static factory methods for the loss implementations that live in this
 * package.
 *
 * <p>NOTE(review): the accumulators are plain fields with no synchronization visible in this
 * class; confirm how callers share instances across threads before relying on concurrent use.
 */
public abstract class Loss extends TrainingMetric {
    /*
     * Accumulated in double rather than float: a float carries only a 24-bit mantissa, so a
     * long-running sum starts dropping small per-batch contributions once the total is large.
     */
    private double totalLoss;

    /*
     * Kept as long because NDArray.size() yields a long; narrowing every increment to int would
     * silently wrap once more than Integer.MAX_VALUE elements have been counted.
     */
    private long totalInstances;

    /**
     * Creates a loss whose metric name is {@code name}.
     *
     * @param name the name passed on to the {@link TrainingMetric} constructor
     */
    public Loss(String name) {
        super(name);
    }

    /**
     * Computes the loss for one set of labels and predictions.
     *
     * <p>{@link #update(NDList, NDList)} sums every element of the returned array and counts
     * each element as one instance.
     *
     * @param labels the ground-truth values, as passed to {@link #update(NDList, NDList)}
     * @param predictions the model outputs, as passed to {@link #update(NDList, NDList)}
     * @return the loss values
     */
    public abstract NDArray getLoss(NDList labels, NDList predictions);

    /**
     * Creates an {@link L1Loss} using its no-argument constructor.
     *
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss() {
        return new L1Loss();
    }

    /**
     * Creates an {@link L1Loss} with explicit settings.
     *
     * @param weight forwarded to the {@link L1Loss} constructor (presumably a scaling factor for
     *     the loss; confirm against {@link L1Loss})
     * @param batchAxis forwarded to the {@link L1Loss} constructor (presumably the axis that
     *     indexes the batch; confirm against {@link L1Loss})
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(float weight, int batchAxis) {
        return new L1Loss(weight, batchAxis);
    }

    /**
     * Creates an {@link L2Loss} using its no-argument constructor.
     *
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss() {
        return new L2Loss();
    }

    /**
     * Creates an {@link L2Loss} with explicit settings.
     *
     * @param weight forwarded to the {@link L2Loss} constructor (presumably a scaling factor for
     *     the loss; confirm against {@link L2Loss})
     * @param batchAxis forwarded to the {@link L2Loss} constructor (presumably the axis that
     *     indexes the batch; confirm against {@link L2Loss})
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(float weight, int batchAxis) {
        return new L2Loss(weight, batchAxis);
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} using its no-argument constructor.
     *
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss() {
        return new SigmoidBinaryCrossEntropyLoss();
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} with explicit settings.
     *
     * <p>All three arguments are forwarded unchanged to the constructor; see {@link
     * SigmoidBinaryCrossEntropyLoss} for their exact meaning.
     *
     * @param weight forwarded to the constructor (presumably a scaling factor; confirm)
     * @param batchAxis forwarded to the constructor (presumably the batch axis; confirm)
     * @param fromSigmoid forwarded to the constructor (presumably whether the predictions have
     *     already been passed through a sigmoid; confirm)
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(
            float weight, int batchAxis, boolean fromSigmoid) {
        return new SigmoidBinaryCrossEntropyLoss(weight, batchAxis, fromSigmoid);
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} using its no-argument constructor.
     *
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss() {
        return new SoftmaxCrossEntropyLoss();
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} with explicit settings.
     *
     * <p>All five arguments are forwarded unchanged to the constructor; see {@link
     * SoftmaxCrossEntropyLoss} for their exact meaning.
     *
     * @param weight forwarded to the constructor (presumably a scaling factor; confirm)
     * @param batchAxis forwarded to the constructor (presumably the batch axis; confirm)
     * @param classAxis forwarded to the constructor (presumably the axis holding the class
     *     scores; confirm)
     * @param sparseLabel forwarded to the constructor (presumably whether labels are class
     *     indices rather than one-hot vectors; confirm)
     * @param fromLogit forwarded to the constructor; its polarity is not visible here, so check
     *     {@link SoftmaxCrossEntropyLoss} before relying on it
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(
            float weight, int batchAxis, int classAxis, boolean sparseLabel, boolean fromLogit) {
        return new SoftmaxCrossEntropyLoss(weight, batchAxis, classAxis, sparseLabel, fromLogit);
    }

    /**
     * Creates a {@link HingeLoss} using its no-argument constructor.
     *
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss() {
        return new HingeLoss();
    }

    /**
     * Creates a {@link HingeLoss} with explicit settings.
     *
     * @param margin forwarded to the {@link HingeLoss} constructor
     * @param weight forwarded to the {@link HingeLoss} constructor (presumably a scaling factor;
     *     confirm)
     * @param batchAxis forwarded to the {@link HingeLoss} constructor (presumably the batch
     *     axis; confirm)
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(int margin, float weight, int batchAxis) {
        return new HingeLoss(margin, weight, batchAxis);
    }

    /**
     * Returns a copy of this loss produced by {@link Object#clone()}.
     *
     * <p>Because the copy is a field-for-field clone, it starts with the same accumulated total
     * and instance count as this object; call {@link #reset()} on it if a fresh accumulator is
     * wanted.
     *
     * @return the cloned loss
     * @throws AssertionError if cloning is unexpectedly unsupported
     */
    @Override
    public Loss duplicate() {
        try {
            return (Loss) clone();
        } catch (CloneNotSupportedException e) {
            throw new AssertionError("Clone is not supported", e);
        }
    }

    /**
     * Computes the loss for the given batch and folds it into the running totals.
     *
     * <p>The sum of all elements of the loss array is added to the running loss, and the number
     * of elements in the loss array is added to the running instance count.
     *
     * @param labels the ground-truth values handed to {@link #getLoss(NDList, NDList)}
     * @param predictions the model outputs handed to {@link #getLoss(NDList, NDList)}
     */
    @Override
    public void update(NDList labels, NDList predictions) {
        NDArray loss = getLoss(labels, predictions);
        // An empty index array reads the single value of the reduced array
        // (sum() with no axes presumably yields a scalar; confirm against NDArray).
        totalLoss += loss.sum().getFloat(new long[0]);
        // size() is a long; the counter is a long too, so no narrowing (and no wrap-around).
        totalInstances += loss.size();
    }

    /** Clears the accumulated loss and instance count. */
    @Override
    public void reset() {
        totalLoss = 0.0;
        totalInstances = 0L;
    }

    /**
     * Returns the mean loss per counted element since the last {@link #reset()}.
     *
     * @return the accumulated loss divided by the accumulated element count, or {@link
     *     Float#NaN} when nothing has been accumulated yet
     */
    @Override
    public float getValue() {
        if (totalInstances == 0L) {
            return Float.NaN;
        }
        // Divide in double precision and narrow only the final result.
        return (float) (totalLoss / totalInstances);
    }

    /**
     * Lists every axis index of {@code loss} except {@code batchAxis}.
     *
     * <p>Axis indices run from {@code 0} to {@code loss.getShape().dimension() - 1}. The
     * comparison is a plain equality check, so a {@code batchAxis} outside that range (for
     * example a negative value) excludes nothing.
     *
     * @param loss the array whose shape determines the number of axes
     * @param batchAxis the axis index to leave out
     * @return the remaining axis indices in ascending order
     */
    protected int[] excludeBatchAxis(NDArray loss, int batchAxis) {
        int rank = loss.getShape().dimension();
        return IntStream.range(0, rank).filter(axis -> axis != batchAxis).toArray();
    }
}

