/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;

/**
 * Softmax cross-entropy loss.
 *
 * <p>The prediction is converted to log-probabilities along {@code classAxis} (unless
 * {@code fromLogit} is set, in which case it is used as-is). The negative log-probability of the
 * target is taken, optionally scaled by {@code weight}, and then averaged via
 * {@code excludeBatchAxis(loss, batchAxis)}.
 */
public class SoftmaxCrossEntropyLoss
extends Loss {
    private final float weight;
    private final int batchAxis;
    private final int classAxis;
    private final boolean sparseLabel;
    private final boolean fromLogit;

    /**
     * Creates a softmax cross-entropy loss.
     *
     * @param weight scale applied to the per-element loss; skipped when it equals {@code 1.0f}
     * @param batchAxis axis passed to {@code excludeBatchAxis} when averaging the final loss
     * @param classAxis axis of the prediction that holds the per-class scores
     * @param sparseLabel {@code true} if the label holds class indices; {@code false} if the label
     *     is dense and can be reshaped to the prediction's shape (presumably one-hot or
     *     probability vectors)
     * @param fromLogit if {@code true}, the prediction is treated as already being
     *     log-probabilities; otherwise softmax followed by log is applied along
     *     {@code classAxis}
     */
    public SoftmaxCrossEntropyLoss(float weight, int batchAxis, int classAxis, boolean sparseLabel, boolean fromLogit) {
        super("SoftmaxCrossEntropyLoss");
        this.weight = weight;
        this.batchAxis = batchAxis;
        this.classAxis = classAxis;
        this.sparseLabel = sparseLabel;
        this.fromLogit = fromLogit;
    }

    /**
     * Creates a softmax cross-entropy loss with the defaults used here.
     *
     * <p>The defaults are: weight {@code 1.0f}, batch axis {@code 0}, class axis {@code -1},
     * sparse (index) labels, and raw-score predictions ({@code fromLogit == false}).
     */
    public SoftmaxCrossEntropyLoss() {
        this(1.0f, 0, -1, true, false);
    }

    /**
     * Computes the softmax cross-entropy between {@code label} and {@code prediction}.
     *
     * @param label a list containing exactly one array with the targets (class indices when
     *     {@code sparseLabel} is set, otherwise a dense array reshapeable to the prediction)
     * @param prediction a list containing exactly one array with the scores (raw scores, or
     *     log-probabilities when {@code fromLogit} is set)
     * @return the loss averaged by {@code loss.mean(excludeBatchAxis(loss, batchAxis))}
     */
    @Override
    public NDArray getLoss(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        if (!this.fromLogit) {
            // Turn raw scores into log-probabilities along the class axis.
            pred = pred.softmax(this.classAxis).log();
        }
        NDArray lab = label.singletonOrThrow();
        NDArray loss;
        if (this.sparseLabel) {
            // Select the log-probability of each target index and negate it.
            // NOTE(review): the trailing `true` is presumably keepDims — confirm against NDArrayEx.pick.
            loss = pred.getNDArrayInternal().pick(lab, this.classAxis, true).neg();
        } else {
            lab = lab.reshape(pred.getShape());
            // Negate so the dense branch yields the same (positive) cross-entropy as the sparse
            // branch. Keep the class axis so both branches produce a loss of the same rank and
            // batchAxis refers to the same dimension when averaging below.
            loss = pred.mul(lab).neg().sum(new int[]{this.classAxis}, true);
        }
        if (this.weight != 1.0f) {
            loss = loss.mul(this.weight);
        }
        // NOTE(review): excludeBatchAxis is inherited from Loss; presumably it returns every axis
        // of `loss` except batchAxis, making this a per-sample mean — confirm in Loss.
        return loss.mean(this.excludeBatchAxis(loss, this.batchAxis));
    }
}

