/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;
import ai.djl.training.loss.Loss;

/**
 * Hinge loss: {@code weight * max(0, margin - label * prediction)}, averaged per the batch axis.
 *
 * <p>The label is reshaped to the prediction's shape before the element-wise product. Labels are
 * conventionally in {-1, 1} for hinge loss; this class does not check that.
 */
public class HingeLoss extends Loss {

    /** Margin subtracted from by {@code label * prediction} before clamping at zero. */
    private final int margin;

    /** Scalar multiplier applied to the per-element loss; skipped when exactly 1. */
    private final float weight;

    /** Axis passed to {@code excludeBatchAxis} when computing the mean. */
    private final int batchAxis;

    /**
     * Creates a hinge loss with explicit settings.
     *
     * @param margin the hinge margin
     * @param weight scalar multiplier applied to the loss
     * @param batchAxis the batch axis handed to {@code excludeBatchAxis} when averaging
     */
    public HingeLoss(int margin, float weight, int batchAxis) {
        super("HingeLoss");
        this.margin = margin;
        this.weight = weight;
        this.batchAxis = batchAxis;
    }

    /** Creates a hinge loss with margin 1, weight 1.0 and batch axis 0. */
    public HingeLoss() {
        this(1, 1.0f, 0);
    }

    /**
     * Computes the hinge loss for a single label/prediction pair.
     *
     * @param label a singleton list holding the labels; reshaped to the prediction's shape
     * @param prediction a singleton list holding the predictions
     * @return the weighted hinge loss averaged over the axes returned by {@code excludeBatchAxis}
     */
    @Override
    public NDArray getLoss(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        // Align label shape with the prediction so the element-wise product is well-defined.
        NDArray labelReshaped = label.singletonOrThrow().reshape(pred.getShape());
        // max(0, margin - y * y_hat), expressed via relu.
        NDArray loss = Activation.relu(NDArrays.sub(margin, labelReshaped.mul(pred)));
        if (weight != 1.0f) {
            // Autoboxes to Float, same as the previous explicit Float.valueOf.
            loss = loss.mul(weight);
        }
        // NOTE(review): excludeBatchAxis is inherited from Loss; presumably it yields every axis
        // except batchAxis, giving one averaged value per sample - confirm in Loss.
        return loss.mean(excludeBatchAxis(loss, batchAxis));
    }
}

