/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.simple.multiclass;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Ranked view of a classifier's output scores.
 *
 * <p>For every row of the supplied outcome array, class indices are ranked by score in
 * descending order. {@link #getRankedIndices()} holds the ranked class indices and
 * {@link #getProbabilities()} holds the matching scores in the same sorted order, so
 * {@code probabilities[r][k]} is the score of class {@code rankedIndices[r][k]} and column 0
 * of each row is the top-ranked class.
 *
 * <p>Getters and setters expose the internal arrays and lists directly (no defensive copies).
 */
public class RankClassificationResult implements Serializable {
    private int[][] rankedIndices;
    private float[][] probabilities;
    private List<String> labels;
    private List<String> maxLabels;

    /**
     * Ranks {@code outcome} using default labels {@code "0"}, {@code "1"}, ... (one per column).
     *
     * @param outcome a vector or matrix of scores; it is not modified
     * @throws ND4JIllegalStateException if {@code outcome} has rank greater than 2
     */
    public RankClassificationResult(INDArray outcome) {
        this(outcome, null);
    }

    /**
     * Ranks {@code outcome} row by row in descending score order.
     *
     * @param outcome a vector or matrix of scores; it is not modified
     * @param labels  label for each column index, or {@code null} to use {@code "0"},
     *                {@code "1"}, ...; the list is copied
     * @throws ND4JIllegalStateException if {@code outcome} has rank greater than 2
     */
    public RankClassificationResult(INDArray outcome, List<String> labels) {
        if (outcome.rank() > 2) {
            throw new ND4JIllegalStateException("Only works with vectors and matrices right now");
        }
        // Nd4j.sortWithIndices sorts its argument in place, so sort a copy to leave the caller's
        // array untouched. Result [0] holds the ranked indices, [1] the scores in sorted order.
        INDArray[] sortedWithIndices = Nd4j.sortWithIndices(outcome.dup(), -1, false);
        INDArray indexes = sortedWithIndices[0];
        INDArray sortedScores = sortedWithIndices[1];

        int outcomeColumns = outcome.columns();
        if (labels == null) {
            this.labels = new ArrayList<>(outcomeColumns);
            for (int c = 0; c < outcomeColumns; ++c) {
                this.labels.add(String.valueOf(c));
            }
        } else {
            this.labels = new ArrayList<>(labels);
        }

        int rows = indexes.rows();
        int cols = indexes.columns();
        this.rankedIndices = new int[rows][cols];
        this.probabilities = new float[outcome.rows()][outcomeColumns];
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                this.rankedIndices[i][j] = indexes.getInt(i, j);
                this.probabilities[i][j] = sortedScores.getFloat(i, j);
            }
        }
        // Use a private helper rather than the overridable maxOutcomes() during construction.
        this.maxLabels = computeMaxLabels();
    }

    /**
     * Returns the label of the top-ranked class for row {@code r}.
     *
     * @param r row index
     * @return {@code labels.get(rankedIndices[r][0])}
     */
    public String maxOutcomeForRow(int r) {
        return this.labels.get(this.rankedIndices[r][0]);
    }

    /**
     * Returns the top-ranked label of every row.
     *
     * <p>The list is computed once and cached. The setters do not refresh this cache.
     *
     * @return one label per row of {@link #getRankedIndices()}
     */
    public List<String> maxOutcomes() {
        if (this.maxLabels == null) {
            this.maxLabels = computeMaxLabels();
        }
        return this.maxLabels;
    }

    /** Builds the top-ranked label of each row from the current indices and labels. */
    private List<String> computeMaxLabels() {
        List<String> result = new ArrayList<>(this.rankedIndices.length);
        for (int[] row : this.rankedIndices) {
            result.add(this.labels.get(row[0]));
        }
        return result;
    }

    public int[][] getRankedIndices() {
        return this.rankedIndices;
    }

    public float[][] getProbabilities() {
        return this.probabilities;
    }

    public List<String> getLabels() {
        return this.labels;
    }

    public List<String> getMaxLabels() {
        return this.maxLabels;
    }

    public void setRankedIndices(int[][] rankedIndices) {
        this.rankedIndices = rankedIndices;
    }

    public void setProbabilities(float[][] probabilities) {
        this.probabilities = probabilities;
    }

    public void setLabels(List<String> labels) {
        this.labels = labels;
    }

    public void setMaxLabels(List<String> maxLabels) {
        this.maxLabels = maxLabels;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RankClassificationResult)) {
            return false;
        }
        RankClassificationResult other = (RankClassificationResult) o;
        // canEqual keeps equals symmetric when subclasses add state.
        return other.canEqual(this)
                && Arrays.deepEquals(this.getRankedIndices(), other.getRankedIndices())
                && Arrays.deepEquals(this.getProbabilities(), other.getProbabilities())
                && Objects.equals(this.getLabels(), other.getLabels())
                && Objects.equals(this.getMaxLabels(), other.getMaxLabels());
    }

    protected boolean canEqual(Object other) {
        return other instanceof RankClassificationResult;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        int result = 1;
        result = result * prime + Arrays.deepHashCode(this.getRankedIndices());
        result = result * prime + Arrays.deepHashCode(this.getProbabilities());
        List<String> currentLabels = this.getLabels();
        result = result * prime + (currentLabels == null ? nullHash : currentLabels.hashCode());
        List<String> currentMaxLabels = this.getMaxLabels();
        result = result * prime + (currentMaxLabels == null ? nullHash : currentMaxLabels.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "RankClassificationResult(rankedIndices=" + Arrays.deepToString(this.getRankedIndices())
                + ", probabilities=" + Arrays.deepToString(this.getProbabilities())
                + ", labels=" + this.getLabels()
                + ", maxLabels=" + this.getMaxLabels() + ")";
    }
}

