/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.evaluator.AbstractAccuracy;
import ai.djl.util.Pair;
import ai.djl.util.Preconditions;

/**
 * An {@link AbstractAccuracy} for binary classification.
 *
 * <p>Each prediction value is binarized by comparing it with {@code threshold}: a value greater
 * than or equal to the threshold is treated as a positive prediction. The binarized predictions
 * and the labels are both converted to {@link DataType#INT64} and compared element-wise; the
 * number of equal positions is the count of correct predictions.
 */
public class BinaryAccuracy extends AbstractAccuracy {

    /** Cut-off applied to predictions; values {@code >= threshold} count as positive. */
    float threshold;

    /**
     * Creates a binary accuracy evaluator with every setting given explicitly.
     *
     * @param name the evaluator name, forwarded to {@link AbstractAccuracy}
     * @param threshold predictions greater than or equal to this value count as positive
     * @param index forwarded to {@link AbstractAccuracy}; presumably the position that
     *     {@code accuracyHelper} reads from the label and prediction lists
     * @param axis forwarded unchanged to {@link AbstractAccuracy}
     */
    public BinaryAccuracy(String name, float threshold, int index, int axis) {
        super(name, index, axis);
        this.threshold = threshold;
    }

    /**
     * Creates a binary accuracy evaluator using axis {@code 1}.
     *
     * @param name the evaluator name
     * @param threshold predictions greater than or equal to this value count as positive
     * @param index forwarded to {@link AbstractAccuracy}
     */
    public BinaryAccuracy(String name, float threshold, int index) {
        this(name, threshold, index, 1);
    }

    /**
     * Creates a binary accuracy evaluator named {@code "BinaryAccuracy"} with index {@code 0} and
     * axis {@code 1}.
     *
     * @param threshold predictions greater than or equal to this value count as positive
     */
    public BinaryAccuracy(float threshold) {
        this("BinaryAccuracy", threshold, 0);
    }

    /**
     * Creates a binary accuracy evaluator named {@code "BinaryAccuracy"} with threshold {@code 0},
     * index {@code 0} and axis {@code 1}.
     */
    public BinaryAccuracy() {
        this("BinaryAccuracy", 0.0f, 0, 1);
    }

    /**
     * Counts how many binarized predictions agree with their labels.
     *
     * @param labels the label arrays; must have the same number of entries as {@code predictions}
     * @param predictions the prediction arrays
     * @return a pair of (number of elements in the selected label array, array holding the number
     *     of positions where the label equals the binarized prediction)
     */
    @Override
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        boolean sameLength = labels.size() == predictions.size();
        Preconditions.checkArgument(sameLength, "labels and prediction length does not match.");

        NDArray truth = labels.get(index);
        NDArray predicted = predictions.get(index);
        checkLabelShapes(truth, predicted, false);

        // Binarize: anything at or above the cut-off is a positive prediction.
        NDArray binarized = predicted.gte(threshold);
        long elementCount = truth.size();

        // Bring both sides to INT64 before the element-wise comparison.
        NDArray truthAsLong = truth.toType(DataType.INT64, false);
        NDArray binarizedAsLong = binarized.toType(DataType.INT64, false);
        NDArray correctCount = truthAsLong.eq(binarizedAsLong).countNonzero();

        return new Pair<>(elementCount, correctCount);
    }
}

