/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.api.util.ndarray;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class RecordConverter {
    private RecordConverter() {
    }

    @Deprecated
    public static INDArray toArray(Collection<Writable> record, int size) {
        return RecordConverter.toArray(record);
    }

    public static List<List<Writable>> toRecords(INDArray matrix) {
        ArrayList<List<Writable>> ret = new ArrayList<List<Writable>>();
        for (int i = 0; i < matrix.rows(); ++i) {
            ret.add(RecordConverter.toRecord(matrix.getRow(i)));
        }
        return ret;
    }

    public static INDArray toMatrix(List<List<Writable>> records) {
        ArrayList<INDArray> toStack = new ArrayList<INDArray>();
        for (List<Writable> l : records) {
            toStack.add(RecordConverter.toArray(l));
        }
        return Nd4j.vstack(toStack);
    }

    /*
     * WARNING - void declaration
     */
    public static INDArray toArray(Collection<? extends Writable> record) {
        ArrayList<? extends Writable> l = record instanceof List ? (ArrayList<? extends Writable>)record : new ArrayList<Writable>(record);
        if (l.size() == 1 && l.get(0) instanceof NDArrayWritable) {
            return ((NDArrayWritable)l.get(0)).get();
        }
        int length = 0;
        for (Writable writable : record) {
            if (writable instanceof NDArrayWritable) {
                INDArray a = ((NDArrayWritable)writable).get();
                if (!a.isRowVector()) {
                    throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is not a row vector. Can only concat row vectors with other writables. Shape: " + Arrays.toString(a.shape()));
                }
                length += a.length();
                continue;
            }
            ++length;
        }
        INDArray arr = Nd4j.create((int)1, (int)length);
        boolean bl = false;
        for (Writable writable : record) {
            void var4_6;
            if (writable instanceof NDArrayWritable) {
                INDArray toPut = ((NDArrayWritable)writable).get();
                arr.put(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)var4_6, (int)(var4_6 + toPut.length()))}, toPut);
                var4_6 += toPut.length();
                continue;
            }
            arr.putScalar(0, (int)var4_6, writable.toDouble());
            ++var4_6;
        }
        return arr;
    }

    public static List<Writable> toRecord(INDArray array) {
        ArrayList<Writable> writables = new ArrayList<Writable>();
        writables.add(new NDArrayWritable(array));
        return writables;
    }

    public static List<List<Writable>> toRecords(DataSet dataSet) {
        if (RecordConverter.isClassificationDataSet(dataSet)) {
            return RecordConverter.getClassificationWritableMatrix(dataSet);
        }
        return RecordConverter.getRegressionWritableMatrix(dataSet);
    }

    private static boolean isClassificationDataSet(DataSet dataSet) {
        INDArray labels = dataSet.getLabels();
        return labels.sum(new int[]{0, 1}).getInt(new int[]{0}) == dataSet.numExamples() && labels.shape()[1] > 1;
    }

    private static List<List<Writable>> getClassificationWritableMatrix(DataSet dataSet) {
        ArrayList<List<Writable>> writableMatrix = new ArrayList<List<Writable>>();
        for (int i = 0; i < dataSet.numExamples(); ++i) {
            List<Writable> writables = RecordConverter.toRecord(dataSet.getFeatures().getRow(i));
            writables.add(new IntWritable(Nd4j.argMax((INDArray)dataSet.getLabels().getRow(i), (int[])new int[]{1}).getInt(new int[]{0})));
            writableMatrix.add(writables);
        }
        return writableMatrix;
    }

    private static List<List<Writable>> getRegressionWritableMatrix(DataSet dataSet) {
        ArrayList<List<Writable>> writableMatrix = new ArrayList<List<Writable>>();
        for (int i = 0; i < dataSet.numExamples(); ++i) {
            List<Writable> writables = RecordConverter.toRecord(dataSet.getFeatures().getRow(i));
            INDArray labelRow = dataSet.getLabels().getRow(i);
            for (int j = 0; j < labelRow.shape()[1]; ++j) {
                writables.add(new DoubleWritable(labelRow.getDouble(j)));
            }
            writableMatrix.add(writables);
        }
        return writableMatrix;
    }
}

