/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;
import ai.djl.training.loss.Loss;

/**
 * Binary cross-entropy loss for binary (0/1) targets.
 *
 * <p>The per-element loss is {@code -(y * log(p) + (1 - y) * log(1 - p))}, where {@code y} is the
 * label and {@code p} is the predicted probability. When {@code fromSigmoid} is {@code true} the
 * prediction is used directly as {@code p}. Otherwise the prediction is treated as a logit and an
 * algebraically equivalent, numerically stable form is evaluated instead. The per-element loss is
 * multiplied by {@code weight}. It is then averaged with {@code mean} over the axes returned by
 * {@code excludeBatchAxis(loss, batchAxis)}.
 */
public class SigmoidBinaryCrossEntropyLoss
extends Loss {
    private final float weight;
    private final int batchAxis;
    private final boolean fromSigmoid;

    /**
     * Creates a sigmoid binary cross-entropy loss.
     *
     * @param weight scalar multiplier applied to the per-element loss before averaging; a value of
     *     {@code 1.0f} skips the multiplication
     * @param batchAxis axis passed to {@code excludeBatchAxis} when averaging (presumably the
     *     batch dimension is the one kept — confirm against {@code Loss})
     * @param fromSigmoid {@code true} if predictions are already probabilities; {@code false} if
     *     they are raw logits
     */
    public SigmoidBinaryCrossEntropyLoss(float weight, int batchAxis, boolean fromSigmoid) {
        super("SigmoidBinaryCrossEntropyLoss");
        this.weight = weight;
        this.batchAxis = batchAxis;
        this.fromSigmoid = fromSigmoid;
    }

    /** Creates a loss with weight {@code 1.0f}, batch axis {@code 0}, and logit predictions. */
    public SigmoidBinaryCrossEntropyLoss() {
        this(1.0f, 0, false);
    }

    /**
     * Computes the binary cross-entropy between labels and predictions.
     *
     * @param label list holding exactly one array of targets; it is reshaped to the prediction's
     *     shape
     * @param prediction list holding exactly one array of logits or probabilities (see
     *     {@code fromSigmoid})
     * @return the weighted loss averaged over the axes given by {@code excludeBatchAxis}
     */
    @Override
    public NDArray getLoss(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        // Give the label the prediction's shape so the element-wise operations below line up.
        NDArray lab = label.singletonOrThrow().reshape(pred.getShape());
        NDArray loss;
        if (!this.fromSigmoid) {
            // Stable BCE-with-logits: max(x, 0) - x * y + log(1 + exp(-|x|)),
            // assuming Activation.softrelu computes log(1 + exp(x)).
            loss = Activation.relu(pred).sub(pred.mul(lab)).add(Activation.softrelu(pred.abs().neg()));
        } else {
            // eps keeps log() finite when a probability is exactly 0 or 1.
            double eps = 1.0E-12;
            NDArray logP = pred.add(eps).log();
            NDArray logOneMinusP = NDArrays.sub(1.0, pred).add(eps).log();
            // Negative log-likelihood: -(y * log(p) + (1 - y) * log(1 - p)).
            loss = logP.mul(lab).add(logOneMinusP.mul(NDArrays.sub(1.0, lab))).neg();
        }
        if (this.weight != 1.0f) {
            loss = loss.mul(this.weight);
        }
        return loss.mean(this.excludeBatchAxis(loss, this.batchAxis));
    }
}

