/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.AbstractCompositeLoss;
import ai.djl.training.loss.Loss;
import ai.djl.util.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * A composite loss built from an ordered list of component {@link Loss} instances.
 *
 * <p>Each component is registered either without an index, in which case it receives the full
 * labels and predictions lists, or with an index, in which case it receives only the labels and
 * predictions entries at that position (each wrapped in a single-element {@link NDList}).
 */
public class SimpleCompositeLoss extends AbstractCompositeLoss {

    /**
     * Per-component position into the labels/predictions lists, kept parallel to {@code
     * components}; a {@code null} entry means that component receives the full lists.
     */
    private final List<Integer> indices;

    /** Creates an empty composite loss named {@code "CompositeLoss"}. */
    public SimpleCompositeLoss() {
        this("CompositeLoss");
    }

    /**
     * Creates an empty composite loss with the given name.
     *
     * @param name the name of this loss
     */
    public SimpleCompositeLoss(String name) {
        super(name);
        components = new ArrayList<>();
        indices = new ArrayList<>();
    }

    /**
     * Adds a component loss that is applied to the full labels and predictions lists.
     *
     * @param loss the component loss to add
     * @return this composite loss, for chaining
     * @throws NullPointerException if {@code loss} is {@code null}
     */
    public SimpleCompositeLoss addLoss(Loss loss) {
        return appendComponent(loss, null);
    }

    /**
     * Adds a component loss that is applied only to the labels and predictions entries at {@code
     * index}.
     *
     * @param loss the component loss to add
     * @param index position in the labels and predictions lists handed to this component
     * @return this composite loss, for chaining
     * @throws NullPointerException if {@code loss} is {@code null}
     * @throws IllegalArgumentException if {@code index} is negative
     */
    public SimpleCompositeLoss addLoss(Loss loss, int index) {
        // Fail fast: a negative index can never address an NDList entry and would otherwise
        // only surface as an IndexOutOfBoundsException during loss evaluation.
        if (index < 0) {
            throw new IllegalArgumentException("index must be non-negative, got " + index);
        }
        return appendComponent(loss, index);
    }

    /** Appends a component and its (possibly null) index, keeping both lists in lockstep. */
    private SimpleCompositeLoss appendComponent(Loss loss, Integer index) {
        components.add(Objects.requireNonNull(loss, "loss"));
        indices.add(index);
        return this;
    }

    /**
     * Selects the labels and predictions handed to the component at {@code componentIndex}.
     *
     * @param componentIndex position of the component, in the order it was added
     * @param labels the full labels list
     * @param predictions the full predictions list
     * @return the full lists if the component was added without an index; otherwise single-element
     *     lists holding the entries at the component's index
     */
    @Override
    protected Pair<NDList, NDList> inputForComponent(
            int componentIndex, NDList labels, NDList predictions) {
        Integer index = indices.get(componentIndex);
        if (index == null) {
            return new Pair<>(labels, predictions);
        }
        return new Pair<>(new NDList(labels.get(index)), new NDList(predictions.get(index)));
    }
}

