/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.Abs;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class RegressionEvaluation {
    public static final int DEFAULT_PRECISION = 5;
    private List<String> columnNames;
    private int precision;
    private int exampleCount = 0;
    private INDArray labelsSumPerColumn;
    private INDArray sumSquaredErrorsPerColumn;
    private INDArray sumAbsErrorsPerColumn;
    private INDArray currentMean;
    private INDArray currentPredictionMean;
    private INDArray m2Actual;
    private INDArray sumOfProducts;
    private INDArray sumSquaredLabels;
    private INDArray sumSquaredPredicted;

    public RegressionEvaluation(int nColumns) {
        this(RegressionEvaluation.createDefaultColumnNames(nColumns), 5);
    }

    public RegressionEvaluation(int nColumns, int precision) {
        this(RegressionEvaluation.createDefaultColumnNames(nColumns), precision);
    }

    public RegressionEvaluation(String ... columnNames) {
        this(Arrays.asList(columnNames), 5);
    }

    public RegressionEvaluation(List<String> columnNames) {
        this(columnNames, 5);
    }

    public RegressionEvaluation(List<String> columnNames, int precision) {
        this.columnNames = columnNames;
        this.precision = precision;
        int n = columnNames.size();
        this.labelsSumPerColumn = Nd4j.zeros((int)n);
        this.sumSquaredErrorsPerColumn = Nd4j.zeros((int)n);
        this.sumAbsErrorsPerColumn = Nd4j.zeros((int)n);
        this.currentMean = Nd4j.zeros((int)n);
        this.m2Actual = Nd4j.zeros((int)n);
        this.currentPredictionMean = Nd4j.zeros((int)n);
        this.sumOfProducts = Nd4j.zeros((int)n);
        this.sumSquaredLabels = Nd4j.zeros((int)n);
        this.sumSquaredPredicted = Nd4j.zeros((int)n);
    }

    private static List<String> createDefaultColumnNames(int nColumns) {
        ArrayList<String> list = new ArrayList<String>(nColumns);
        for (int i = 0; i < nColumns; ++i) {
            list.add("col_" + i);
        }
        return list;
    }

    public void eval(INDArray labels, INDArray predictions) {
        this.labelsSumPerColumn.addi(labels.sum(new int[]{0}));
        INDArray error = predictions.sub(labels);
        INDArray absErrorSum = Nd4j.getExecutioner().execAndReturn((TransformOp)new Abs(error.dup())).sum(new int[]{0});
        INDArray squaredErrorSum = error.mul(error).sum(new int[]{0});
        this.sumAbsErrorsPerColumn.addi(absErrorSum);
        this.sumSquaredErrorsPerColumn.addi(squaredErrorSum);
        this.sumOfProducts.addi(labels.mul(predictions).sum(new int[]{0}));
        this.sumSquaredLabels.addi(labels.mul(labels).sum(new int[]{0}));
        this.sumSquaredPredicted.addi(predictions.mul(predictions).sum(new int[]{0}));
        int nRows = labels.size(0);
        this.currentMean.muli((Number)this.exampleCount).addi(labels.sum(new int[]{0})).divi((Number)(this.exampleCount + nRows));
        this.currentPredictionMean.muli((Number)this.exampleCount).addi(predictions.sum(new int[]{0})).divi((Number)(this.exampleCount + nRows));
        this.exampleCount += nRows;
    }

    public void evalTimeSeries(INDArray labels, INDArray predictions) {
        if (labels.rank() == 2 && predictions.rank() == 2) {
            this.eval(labels, predictions);
        }
        if (labels.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: labels are not rank 3 (rank=" + labels.rank() + ")");
        }
        if (!Arrays.equals(labels.shape(), predictions.shape())) {
            throw new IllegalArgumentException("Labels and predicted have different shapes: labels=" + Arrays.toString(labels.shape()) + ", predictions=" + Arrays.toString(predictions.shape()));
        }
        if (labels.ordering() == 'f') {
            labels = Shape.toOffsetZeroCopy((INDArray)labels, (char)'c');
        }
        if (predictions.ordering() == 'f') {
            predictions = Shape.toOffsetZeroCopy((INDArray)predictions, (char)'c');
        }
        int[] shape = labels.shape();
        labels = labels.permute(new int[]{0, 2, 1});
        labels = labels.reshape(shape[0] * shape[2], shape[1]);
        predictions = predictions.permute(new int[]{0, 2, 1});
        predictions = predictions.reshape(shape[0] * shape[2], shape[1]);
        this.eval(labels, predictions);
    }

    public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray outputMask) {
        int totalOutputExamples = outputMask.sumNumber().intValue();
        int outSize = labels.size(1);
        INDArray labels2d = Nd4j.create((int)totalOutputExamples, (int)outSize);
        INDArray predicted2d = Nd4j.create((int)totalOutputExamples, (int)outSize);
        int rowCount = 0;
        for (int ex = 0; ex < outputMask.size(0); ++ex) {
            for (int t = 0; t < outputMask.size(1); ++t) {
                if (outputMask.getDouble(ex, t) == 0.0) continue;
                labels2d.putRow(rowCount, labels.get(new INDArrayIndex[]{NDArrayIndex.point((int)ex), NDArrayIndex.all(), NDArrayIndex.point((int)t)}));
                predicted2d.putRow(rowCount, predictions.get(new INDArrayIndex[]{NDArrayIndex.point((int)ex), NDArrayIndex.all(), NDArrayIndex.point((int)t)}));
                ++rowCount;
            }
        }
        this.eval(labels2d, predicted2d);
    }

    public String stats() {
        int maxLabelLength = 0;
        for (String s : this.columnNames) {
            maxLabelLength = Math.max(maxLabelLength, s.length());
        }
        int labelWidth = maxLabelLength + 5;
        int columnWidth = this.precision + 10;
        String format = "%-" + labelWidth + "s%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e";
        StringBuilder sb = new StringBuilder();
        String headerFormat = "%-" + labelWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s";
        sb.append(String.format(headerFormat, "Column", "MSE", "MAE", "RMSE", "RSE", "R^2"));
        sb.append("\n");
        for (int i = 0; i < this.columnNames.size(); ++i) {
            double mse = this.meanSquaredError(i);
            double mae = this.meanAbsoluteError(i);
            double rmse = this.rootMeanSquaredError(i);
            double rse = this.relativeSquaredError(i);
            double corr = this.correlationR2(i);
            sb.append(String.format(format, this.columnNames.get(i), mse, mae, rmse, rse, corr));
            sb.append("\n");
        }
        return sb.toString();
    }

    public int numColumns() {
        return this.columnNames.size();
    }

    public double meanSquaredError(int column) {
        return this.sumSquaredErrorsPerColumn.getDouble(column) / (double)this.exampleCount;
    }

    public double meanAbsoluteError(int column) {
        return this.sumAbsErrorsPerColumn.getDouble(column) / (double)this.exampleCount;
    }

    public double rootMeanSquaredError(int column) {
        return Math.sqrt(this.sumSquaredErrorsPerColumn.getDouble(column) / (double)this.exampleCount);
    }

    public double correlationR2(int column) {
        double sumxiyi = this.sumOfProducts.getDouble(column);
        double predictionMean = this.currentPredictionMean.getDouble(column);
        double labelMean = this.currentMean.getDouble(column);
        double sumSquaredLabels = this.sumSquaredLabels.getDouble(column);
        double sumSquaredPredicted = this.sumSquaredPredicted.getDouble(column);
        double r2 = sumxiyi - (double)this.exampleCount * predictionMean * labelMean;
        return r2 /= Math.sqrt(sumSquaredLabels - (double)this.exampleCount * labelMean * labelMean) * Math.sqrt(sumSquaredPredicted - (double)this.exampleCount * predictionMean * predictionMean);
    }

    public double relativeSquaredError(int column) {
        double numerator = this.sumSquaredPredicted.getDouble(column) - 2.0 * this.sumOfProducts.getDouble(column) + this.sumSquaredLabels.getDouble(column);
        double denominator = this.sumSquaredLabels.getDouble(column) - (double)this.exampleCount * this.currentMean.getDouble(column) * this.currentMean.getDouble(column);
        if (Math.abs(denominator) > Nd4j.EPS_THRESHOLD) {
            return numerator / denominator;
        }
        return Double.POSITIVE_INFINITY;
    }
}

