/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.zoo.util;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import org.deeplearning4j.zoo.util.ClassPrediction;
import org.deeplearning4j.zoo.util.Labels;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Base implementation of {@link Labels} that holds class names in an in-memory list and decodes
 * network output arrays into ranked {@link ClassPrediction}s.
 *
 * <p>Subclasses supply the label list either by overriding {@link #getLabels()} or by passing a
 * classpath resource name to {@link #BaseLabels(String)}.
 */
public abstract class BaseLabels implements Labels {
    /** Class names indexed by class id; {@code null} if the subclass supplies none. */
    protected ArrayList<String> labels = null;

    /**
     * Initializes {@link #labels} from {@link #getLabels()}.
     *
     * <p>Note: this calls an overridable method from the constructor, so a subclass override runs
     * before the subclass's own field initializers have executed.
     *
     * @throws IOException if {@link #getLabels()} fails to load the labels
     */
    protected BaseLabels() throws IOException {
        this.labels = this.getLabels();
    }

    /**
     * Initializes {@link #labels} by reading {@code textResource} one label per line.
     *
     * @param textResource resource name, resolved via {@link Class#getResourceAsStream(String)} on
     *     the runtime class
     * @throws IOException if the resource cannot be found or read
     */
    protected BaseLabels(String textResource) throws IOException {
        this.labels = this.getLabels(textResource);
    }

    /**
     * Hook for subclasses that build their label list programmatically.
     *
     * @return the label list; this base implementation returns {@code null}
     * @throws IOException if loading the labels fails
     */
    protected ArrayList<String> getLabels() throws IOException {
        return null;
    }

    /**
     * Reads a label list from a resource, one label per line, in file order.
     *
     * @param textResource resource name, resolved via {@link Class#getResourceAsStream(String)} on
     *     the runtime class
     * @return the labels, where line {@code i} becomes the label for class {@code i}
     * @throws IOException if the resource does not exist or an error occurs while reading it
     */
    protected ArrayList<String> getLabels(String textResource) throws IOException {
        InputStream in = this.getClass().getResourceAsStream(textResource);
        if (in == null) {
            // getResourceAsStream reports "not found" as null; Scanner would otherwise throw an NPE.
            throw new IOException("Label resource not found: " + textResource);
        }
        ArrayList<String> result = new ArrayList<String>();
        // Decode explicitly as UTF-8 rather than relying on the platform default charset.
        try (Scanner s = new Scanner(in, "UTF-8")) {
            while (s.hasNextLine()) {
                result.add(s.nextLine());
            }
            // Scanner swallows read errors; surface them instead of returning a truncated list.
            IOException readError = s.ioException();
            if (readError != null) {
                throw readError;
            }
        }
        return result;
    }

    /**
     * Returns the class name for class index {@code n}.
     *
     * @param n class index
     * @return the label at position {@code n} of {@link #labels}
     * @throws IndexOutOfBoundsException if {@code n} is outside the label list
     */
    @Override
    public String getLabel(int n) {
        return this.labels.get(n);
    }

    /**
     * Decodes each row of {@code predictions} into its top {@code n} class predictions.
     *
     * <p>Column {@code j} of a row is treated as the score for class {@code j}. A column vector
     * (or scalar) is flattened with {@code ravel()} first and then decoded row by row.
     *
     * @param predictions 2-D score array, one row per example (presumably softmax probabilities of
     *     shape {@code [batch, numClasses]} — TODO confirm for callers)
     * @param n number of predictions to return per row; clamped to the number of columns
     * @return one list per row, each ordered by descending score
     */
    @Override
    public List<List<ClassPrediction>> decodePredictions(INDArray predictions, int n) {
        int rows = predictions.size(0);
        int cols = predictions.size(1);
        if (predictions.isColumnVectorOrScalar()) {
            predictions = predictions.ravel();
            rows = predictions.size(0);
            cols = predictions.size(1);
        }
        // Asking for more predictions than there are classes would index past the sorted row.
        int top = Math.min(n, cols);
        // Class ids 0..cols-1. linspace includes both endpoints, so the upper bound is cols - 1;
        // using cols here would give the last class id cols (out of range for getLabel).
        INDArray classIds = Nd4j.linspace(0, cols - 1, cols);
        List<List<ClassPrediction>> descriptions = new ArrayList<List<ClassPrediction>>(rows);
        for (int batch = 0; batch < rows; ++batch) {
            // Row 0 holds class ids and row 1 holds scores. Sorting columns by row 1 (descending)
            // keeps each id paired with its score.
            INDArray result = Nd4j.vstack(classIds, predictions.getRow(batch));
            result = Nd4j.sortColumns(result, 1, false);
            List<ClassPrediction> current = new ArrayList<ClassPrediction>(top);
            for (int i = 0; i < top; ++i) {
                // Ids are stored as floating point; round rather than truncate.
                int label = (int) Math.round(result.getDouble(0, i));
                double prob = result.getDouble(1, i);
                current.add(new ClassPrediction(label, this.getLabel(label), prob));
            }
            descriptions.add(current);
        }
        return descriptions;
    }
}

