/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.evaluator.AbstractAccuracy;
import ai.djl.util.Pair;
import java.util.stream.IntStream;

/**
 * Accuracy evaluator that counts a sample as correct when its label is among the {@code topK}
 * highest-scoring classes of the prediction.
 *
 * <p>{@code topK} must be greater than 1; the constructors reject smaller values. It is fixed at
 * construction and is never modified by evaluation. If a prediction has fewer classes than
 * {@code topK}, the value is clamped for that call only.
 */
public class TopKAccuracy
extends AbstractAccuracy {
    /** Number of top-scoring classes compared against the label; immutable after construction. */
    private final int topK;

    /**
     * Creates a top-k accuracy evaluator.
     *
     * @param name the name of the evaluator
     * @param index the position of the label/prediction arrays inside their {@link NDList}s
     * @param topK how many of the highest-scoring classes count as a hit; must be greater than 1
     * @throws IllegalArgumentException if {@code topK} is not greater than 1
     */
    public TopKAccuracy(String name, int index, int topK) {
        super(name, index);
        if (topK <= 1) {
            throw new IllegalArgumentException("Please use TopKAccuracy with topK more than 1");
        }
        this.topK = topK;
    }

    /**
     * Creates a top-k accuracy evaluator named {@code Top_<topK>_Accuracy}.
     *
     * @param index the position of the label/prediction arrays inside their {@link NDList}s
     * @param topK how many of the highest-scoring classes count as a hit; must be greater than 1
     * @throws IllegalArgumentException if {@code topK} is not greater than 1
     */
    public TopKAccuracy(int index, int topK) {
        this("Top_" + topK + "_Accuracy", index, topK);
    }

    /**
     * Creates a top-k accuracy evaluator named {@code Top_<topK>_Accuracy} that reads index 0.
     *
     * @param topK how many of the highest-scoring classes count as a hit; must be greater than 1
     * @throws IllegalArgumentException if {@code topK} is not greater than 1
     */
    public TopKAccuracy(int topK) {
        this("Top_" + topK + "_Accuracy", 0, topK);
    }

    /**
     * Counts how many samples have their label among the top-k predicted classes.
     *
     * @param labels the labels; the array at {@code index} is used
     * @param predictions the predictions; the array at {@code index} is used
     * @return a pair of (number of samples, taken from the label's first dimension; number of
     *     correct samples as an {@link NDArray})
     * @throws IllegalArgumentException if the prediction does not have 1 or 2 dimensions
     */
    @Override
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        NDArray label = (NDArray) labels.get(this.index);
        NDArray prediction = (NDArray) predictions.get(this.index);
        this.checkLabelShapes(label, prediction);
        // NOTE(review): indices are read from the last column backwards below, which assumes
        // argSort orders ascending so the best-scoring classes end up last — confirm.
        NDArray sortedIndices = prediction.argSort(this.axis).toType(DataType.INT32, false);
        int numDims = sortedIndices.getShape().dimension();
        if (numDims != 1 && numDims != 2) {
            throw new IllegalArgumentException(
                    "Prediction should have 1 or 2 dimensions, but got " + numDims);
        }
        // Flatten the label once instead of once per candidate column.
        NDArray flatLabel = label.flatten();
        NDArray numCorrect;
        if (numDims == 1) {
            // NOTE(review): compares the sorted index array element-wise with the label.
            numCorrect = sortedIndices.flatten().eq(flatLabel).countNonzero();
        } else {
            // NOTE(review): assumes the class dimension is dimension 1 (i.e. axis == 1) — confirm.
            int numClasses = (int) sortedIndices.getShape().get(1);
            // Clamp locally: writing the clamped value back to the field would silently shrink
            // k for every later evaluation.
            int k = Math.min(this.topK, numClasses);
            NDArray[] hitsPerRank = IntStream.range(0, k)
                    .mapToObj(j -> sortedIndices.get(":, " + (numClasses - j - 1))
                            .flatten()
                            .eq(flatLabel)
                            .countNonzero())
                    .toArray(NDArray[]::new);
            numCorrect = NDArrays.add(hitsPerRank);
        }
        long total = label.getShape().get(0);
        return new Pair<>(total, numCorrect);
    }
}

