/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.records;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LossCurve {
    private List<String> lossNames;
    private INDArray lossValues;

    public LossCurve(List<Loss> losses) {
        this.lossNames = Collections.unmodifiableList(losses.get(0).getLossNames());
        int numLossValues = losses.get(0).lossValues().length;
        this.lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length);
        for (int i = 0; i < losses.size(); ++i) {
            Loss l = losses.get(i);
            Preconditions.checkArgument((boolean)l.getLossNames().equals(this.lossNames), (String)"Loss names for loss %s differ from others.  Expected %s, got %s", (Object)i, this.lossNames, l.getLossNames());
            Preconditions.checkArgument((l.getLosses().length == numLossValues ? 1 : 0) != 0, (String)"Number of loss values for loss %s differ from others.  Expected %s, got %s", (int)i, (int)numLossValues, (int)l.getLosses().length);
            this.lossValues = this.lossValues.putRow(i, Nd4j.createFromArray(l.getLosses()).castTo(DataType.FLOAT));
        }
    }

    public LossCurve(double[] lossValues, List<String> lossNames) {
        this.lossValues = Nd4j.createFromArray((double[][])new double[][]{lossValues}).castTo(DataType.FLOAT);
        this.lossNames = lossNames;
    }

    protected LossCurve(INDArray lossValues, List<String> lossNames) {
        Preconditions.checkArgument((lossValues.rank() == 2 ? 1 : 0) != 0, (String)"lossValues must have a rank of 2, got %s", (int)lossValues.rank());
        Preconditions.checkArgument((lossValues.dataType() == DataType.FLOAT ? 1 : 0) != 0, (String)"lossValues must be type FLOAT, got %s", (Object)lossValues.dataType());
        this.lossValues = lossValues;
        this.lossNames = lossNames;
    }

    public List<Loss> losses() {
        ArrayList<Loss> losses = new ArrayList<Loss>();
        int i = 0;
        while ((long)i < this.lossValues.size(0)) {
            losses.add(new Loss(this.lossNames, this.lossValues.getRow(i).toDoubleVector()));
            ++i;
        }
        return losses;
    }

    public Loss meanLoss(int epoch) {
        if (epoch >= 0) {
            return new Loss(this.lossNames, this.lossValues.getRow(epoch).toDoubleVector());
        }
        return new Loss(this.lossNames, this.lossValues.getRow(this.lossValues.rows() + epoch).toDoubleVector());
    }

    public Loss lastMeanLoss() {
        return this.meanLoss(-1);
    }

    public float[] meanLoss(@NonNull String lossName) {
        if (lossName == null) {
            throw new NullPointerException("lossName is marked @NonNull but is null");
        }
        int idx = this.lossNames.indexOf(lossName);
        Preconditions.checkArgument((idx >= 0 ? 1 : 0) != 0, (String)"No loss value for %s.  Existing losses: %s", (Object)lossName, this.lossNames);
        float[] loss = new float[(int)this.lossValues.size(0)];
        int i = 0;
        while ((long)i < this.lossValues.size(0)) {
            loss[i] = this.lossValues.getFloat(i, idx);
            ++i;
        }
        return loss;
    }

    public float[] meanLoss(@NonNull SDVariable loss) {
        if (loss == null) {
            throw new NullPointerException("loss is marked @NonNull but is null");
        }
        return this.meanLoss(loss.getVarName());
    }

    public float meanLoss(@NonNull String lossName, int epoch) {
        if (lossName == null) {
            throw new NullPointerException("lossName is marked @NonNull but is null");
        }
        int idx = this.lossNames.indexOf(lossName);
        Preconditions.checkArgument((idx >= 0 ? 1 : 0) != 0, (String)"No loss value for %s.  Existing losses: %s", (Object)lossName, this.lossNames);
        if (epoch >= 0) {
            return this.lossValues.getFloat(epoch, idx);
        }
        return this.lossValues.getFloat(this.lossValues.rows() + epoch, idx);
    }

    public float meanLoss(@NonNull SDVariable loss, int epoch) {
        if (loss == null) {
            throw new NullPointerException("loss is marked @NonNull but is null");
        }
        return this.meanLoss(loss.getVarName(), epoch);
    }

    public float lastMeanLoss(@NonNull String lossName) {
        if (lossName == null) {
            throw new NullPointerException("lossName is marked @NonNull but is null");
        }
        int idx = this.lossNames.indexOf(lossName);
        Preconditions.checkArgument((idx >= 0 ? 1 : 0) != 0, (String)"No loss value for %s.  Existing losses: %s", (Object)lossName, this.lossNames);
        return this.lossValues.getFloat(this.lossValues.rows() - 1, idx);
    }

    public float lastMeanLoss(@NonNull SDVariable loss) {
        if (loss == null) {
            throw new NullPointerException("loss is marked @NonNull but is null");
        }
        return this.lastMeanLoss(loss.getVarName());
    }

    public Loss lastMeanDelta() {
        return this.lastMeanLoss().sub(this.meanLoss(-2));
    }

    public double lastMeanDelta(String lossName) {
        return this.lastMeanDelta().getLoss(lossName);
    }

    public double lastMeanDelta(SDVariable loss) {
        return this.lastMeanDelta(loss.getVarName());
    }

    public LossCurve addLossAndCopy(Loss loss) {
        return this.addLossAndCopy(loss.getLosses(), loss.lossNames());
    }

    public LossCurve addLossAndCopy(double[] values, List<String> lossNames) {
        return new LossCurve(Nd4j.concat(0, this.lossValues, Nd4j.createFromArray((double[][])new double[][]{values}).castTo(DataType.FLOAT)), lossNames);
    }

    public List<String> getLossNames() {
        return this.lossNames;
    }

    public INDArray getLossValues() {
        return this.lossValues;
    }
}

