/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.metrics;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.metrics.TrainingMetric;

/**
 * Training metric that reports the fraction of predictions equal to their labels.
 *
 * <p>Correct and total counts are accumulated across calls to {@link #update(NDArray, NDArray)}
 * (or {@link #update(NDList, NDList)}) until {@link #reset()} is called. {@link #getValue()}
 * returns {@code correct / total}.
 */
public class Accuracy
extends TrainingMetric {
    private long correctInstances;
    private long totalInstances;
    /**
     * Axis along which predictions are reduced with {@code argMax} when their shape differs from
     * the labels' shape.
     */
    protected int axis;
    /** Position, inside the label/prediction {@link NDList}s, of the arrays this metric evaluates. */
    protected int index;

    /**
     * Creates an accuracy metric.
     *
     * @param name the metric name, forwarded to {@link TrainingMetric}
     * @param index position of the label/prediction arrays inside the {@link NDList}s passed to
     *     {@link #update(NDList, NDList)}
     * @param axis axis along which predictions are reduced with {@code argMax} when their shape
     *     differs from the labels' shape
     */
    public Accuracy(String name, int index, int axis) {
        super(name);
        this.axis = axis;
        this.index = index;
    }

    /** Creates a metric named {@code "Accuracy"} that evaluates element 0 and reduces along axis 1. */
    public Accuracy() {
        this("Accuracy", 0, 1);
    }

    /**
     * Creates an accuracy metric that reduces predictions along axis 1.
     *
     * @param name the metric name, forwarded to {@link TrainingMetric}
     * @param index position of the label/prediction arrays inside the {@link NDList}s passed to
     *     {@link #update(NDList, NDList)}
     */
    public Accuracy(String name, int index) {
        this(name, index, 1);
    }

    /** Clears the accumulated correct and total counts. */
    @Override
    public void reset() {
        this.correctInstances = 0L;
        this.totalInstances = 0L;
    }

    /**
     * Accumulates accuracy counts for one pair of arrays.
     *
     * <p>If {@code predictions} does not have the same shape as {@code labels}, it is first reduced
     * with {@code argMax} along {@link #axis}. Labels and (possibly reduced) predictions are then
     * cast to {@code INT64} and compared element-wise. The number of equal elements is added to
     * the correct count, and {@code labels.size()} is added to the total count.
     *
     * @param labels the ground-truth values
     * @param predictions the model outputs; reduced with {@code argMax} along {@link #axis} when
     *     their shape differs from {@code labels}
     */
    public void update(NDArray labels, NDArray predictions) {
        // Inherited from TrainingMetric; presumably rejects incompatible shapes — confirm there.
        this.checkLabelShapes(labels, predictions);
        NDArray predictionReduced = !labels.getShape().equals(predictions.getShape()) ? predictions.argMax(this.axis) : predictions;
        long numCorrect = labels.asType(DataType.INT64, false).eq(predictionReduced.asType(DataType.INT64, false)).countNonzero().getLong(new long[0]);
        this.addCorrectInstances(numCorrect);
        this.addTotalInstances(labels.size());
    }

    /**
     * Accumulates accuracy counts using the element at {@link #index} of each list.
     *
     * @param labels the list holding the ground-truth array
     * @param predictions the list holding the prediction array
     * @throws IllegalArgumentException if the two lists have different sizes
     */
    @Override
    public void update(NDList labels, NDList predictions) {
        if (labels.size() != predictions.size()) {
            throw new IllegalArgumentException("labels and prediction length does not match.");
        }
        this.update((NDArray)labels.get(this.index), (NDArray)predictions.get(this.index));
    }

    /**
     * Returns the accumulated accuracy.
     *
     * @return {@code correct / total} over everything accumulated since the last reset, or
     *     {@link Float#NaN} if no instances have been counted
     */
    @Override
    public float getValue() {
        if (this.totalInstances == 0L) {
            return Float.NaN;
        }
        // Divide in double: widening each long straight to float rounds the counts themselves once
        // they exceed 2^24, whereas double holds them exactly up to 2^53. For smaller counts the
        // result is identical to float division.
        return (float) ((double) this.correctInstances / (double) this.totalInstances);
    }

    /**
     * Adds to the running count of correctly predicted instances.
     *
     * @param numInstances the number of correct instances to add
     */
    public void addCorrectInstances(long numInstances) {
        this.correctInstances += numInstances;
    }

    /**
     * Adds to the running count of evaluated instances.
     *
     * @param totalInstances the number of evaluated instances to add
     */
    public void addTotalInstances(long totalInstances) {
        this.totalInstances += totalInstances;
    }
}

