/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;

/**
 * Loss for Single Shot Detection (SSD) models: the sum of a classification loss and a masked
 * bounding-box regression loss.
 *
 * <p>Per-anchor training targets are derived from the ground truth with a {@link MultiBoxTarget}.
 * The class component is scored with {@link Loss#softmaxCrossEntropyLoss()} and the box component
 * with {@link Loss#l1Loss()}.
 */
public class SingleShotDetectionLoss extends Loss {

    /** Scores the class predictions against the class labels produced by the target generator. */
    private final Loss softmaxLoss = Loss.softmaxCrossEntropyLoss();

    /** Scores the masked bounding-box predictions against the masked bounding-box labels. */
    private final Loss l1Loss = Loss.l1Loss();

    /** Builds bounding-box labels, bounding-box masks and class labels from the ground truth. */
    private final MultiBoxTarget multiBoxTarget = new MultiBoxTarget.Builder().build();

    /**
     * Creates a new SSD loss.
     *
     * @param name the name of this loss, forwarded to {@link Loss}
     */
    public SingleShotDetectionLoss(String name) {
        super(name);
    }

    /**
     * Computes the SSD loss as classification loss plus masked L1 bounding-box loss.
     *
     * @param labels the ground truth; only {@code labels.head()} is read
     * @param predictions the model outputs, read positionally as: index 0 anchors, index 1 class
     *     predictions (rank 3, since its last two axes are transposed), index 2 bounding-box
     *     predictions
     * @return the sum of the classification loss and the bounding-box loss
     */
    @Override
    public NDArray getLoss(NDList labels, NDList predictions) {
        NDArray anchors = predictions.get(0);
        NDArray classPredictions = predictions.get(1);
        NDArray boundingBoxPredictions = predictions.get(2);

        // The target generator receives the class predictions with axes 1 and 2 swapped.
        // NOTE(review): presumably MultiBoxTarget expects the class axis at position 1 — confirm.
        NDList targets =
                multiBoxTarget.target(
                        new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));
        NDArray boundingBoxLabels = targets.get(0);
        NDArray boundingBoxMasks = targets.get(1);
        NDArray classLabels = targets.get(2);

        NDArray classLoss =
                softmaxLoss.getLoss(new NDList(classLabels), new NDList(classPredictions));

        // Both operands are multiplied by the same mask, so wherever the mask is zero the
        // label and prediction are both zero and add nothing to the L1 difference.
        NDArray boundingBoxLoss =
                l1Loss.getLoss(
                        new NDList(boundingBoxLabels.mul(boundingBoxMasks)),
                        new NDList(boundingBoxPredictions.mul(boundingBoxMasks)));
        return classLoss.add(boundingBoxLoss);
    }
}

