/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DataSet
implements org.nd4j.linalg.dataset.api.DataSet {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 1935520764586513365L;
    /** Class logger. */
    private static final Logger log = LoggerFactory.getLogger(DataSet.class);
    // Bit flags of the header byte used by the binary format; load(InputStream)
    // and save(OutputStream) test/set these same bit values (1, 2, 4, 8, 16).
    private static final byte BITMASK_FEATURES_PRESENT = 1;
    private static final byte BITMASK_LABELS_PRESENT = 2;
    private static final byte BITMASK_LABELS_SAME_AS_FEATURES = 4;
    private static final byte BITMASK_FEATURE_MASK_PRESENT = 8;
    private static final byte BITMASK_LABELS_MASK_PRESENT = 16;
    /** Names of the feature columns; starts out empty. */
    private List<String> columnNames = new ArrayList<String>();
    /** Names of the labels, looked up by index in getLabelName(int); starts out empty. */
    private List<String> labelNames = new ArrayList<String>();
    /** Feature array; may be null (see the no-arg constructor). */
    private INDArray features;
    /** Label array; may be null (see the no-arg constructor). */
    private INDArray labels;
    /** Optional mask for the features; null when absent. */
    private INDArray featuresMask;
    /** Optional mask for the labels; null when absent. */
    private INDArray labelsMask;
    /** Flag set by markAsPreProcessed(); transient, so not part of Java serialization. */
    private transient boolean preProcessed = false;

    /** Creates a data set with null features and null labels. */
    public DataSet() {
        this(null, null);
    }

    /**
     * Creates a data set without mask arrays.
     *
     * @param first  stored as the features
     * @param second stored as the labels
     */
    public DataSet(INDArray first, INDArray second) {
        this(first, second, null, null);
    }

    /**
     * Creates a data set from the given arrays. The references are stored as-is (no copy).
     *
     * @param features     feature array, may be null
     * @param labels       label array, may be null
     * @param featuresMask mask for the features, or null
     * @param labelsMask   mask for the labels, or null
     */
    public DataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        this.features = features;
        this.labels = labels;
        this.featuresMask = featuresMask;
        this.labelsMask = labelsMask;
    }

    /** @return true once {@link #markAsPreProcessed()} has been called on this instance */
    public boolean isPreProcessed() {
        return this.preProcessed;
    }

    /** Sets the pre-processed flag; nothing in the visible code ever clears it again. */
    public void markAsPreProcessed() {
        this.preProcessed = true;
    }

    /** @return a new data set whose features and labels are both null */
    public static DataSet empty() {
        return new DataSet(null, null);
    }

    /**
     * Merges several data sets into one by stacking their examples along dimension 0.
     * <p>
     * The merge strategy is chosen from the rank of the FIRST data set's features and
     * labels: rank 2 uses {@code merge2d}, rank 3 uses {@code mergeTimeSeries} (which may
     * also produce a mask array), rank 4 uses {@code merge4dCnnData}. Mask arrays are only
     * collected when the features or the labels are rank 3.
     *
     * @param data  data sets to merge; must not be empty
     * @param clone currently unused by this implementation
     * @return the merged data set
     * @throws IllegalArgumentException if {@code data} is empty
     * @throws IllegalStateException    if the features or labels rank is outside 2..4
     */
    public static DataSet merge(List<DataSet> data, boolean clone) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }
        DataSet first = data.get(0);
        int rankFeatures = first.getFeatures().rank();
        int rankLabels = first.getLabels().rank();
        INDArray[] featuresToMerge = new INDArray[data.size()];
        INDArray[] labelsToMerge = new INDArray[data.size()];
        int count = 0;
        boolean hasFeaturesMaskArray = false;
        boolean hasLabelsMaskArray = false;
        for (DataSet ds : data) {
            featuresToMerge[count] = ds.getFeatureMatrix();
            labelsToMerge[count++] = ds.getLabels();
            // Masks are only relevant when at least one side is time series (rank 3) data.
            if (rankFeatures != 3 && rankLabels != 3) {
                continue;
            }
            hasFeaturesMaskArray |= ds.getFeaturesMaskArray() != null;
            hasLabelsMaskArray |= ds.getLabelsMaskArray() != null;
        }

        INDArray featuresOut;
        INDArray featuresMaskOut;
        switch (rankFeatures) {
            case 2: {
                featuresOut = DataSet.merge2d(featuresToMerge);
                featuresMaskOut = null;
                break;
            }
            case 3: {
                // Pass null (not an array of nulls) when no data set has a mask, so that
                // mergeTimeSeries only creates a mask if the lengths differ.
                INDArray[] featuresMasks = null;
                if (hasFeaturesMaskArray) {
                    featuresMasks = new INDArray[featuresToMerge.length];
                    count = 0;
                    for (DataSet ds : data) {
                        featuresMasks[count++] = ds.getFeaturesMaskArray();
                    }
                }
                INDArray[] merged = DataSet.mergeTimeSeries(featuresToMerge, featuresMasks);
                featuresOut = merged[0];
                featuresMaskOut = merged[1];
                break;
            }
            case 4: {
                featuresOut = DataSet.merge4dCnnData(featuresToMerge);
                featuresMaskOut = null;
                break;
            }
            default: {
                throw new IllegalStateException("Cannot merge examples: features rank must be in range 2 to 4 inclusive. First example features shape: " + Arrays.toString(data.get(0).getFeatureMatrix().shape()));
            }
        }

        INDArray labelsOut;
        INDArray labelsMaskOut;
        switch (rankLabels) {
            case 2: {
                labelsOut = DataSet.merge2d(labelsToMerge);
                labelsMaskOut = null;
                break;
            }
            case 3: {
                INDArray[] labelsMasks = null;
                if (hasLabelsMaskArray) {
                    labelsMasks = new INDArray[labelsToMerge.length];
                    count = 0;
                    for (DataSet ds : data) {
                        labelsMasks[count++] = ds.getLabelsMaskArray();
                    }
                }
                INDArray[] merged = DataSet.mergeTimeSeries(labelsToMerge, labelsMasks);
                labelsOut = merged[0];
                labelsMaskOut = merged[1];
                break;
            }
            case 4: {
                // Fixed: this branch previously merged the FEATURES arrays into the labels.
                labelsOut = DataSet.merge4dCnnData(labelsToMerge);
                labelsMaskOut = null;
                break;
            }
            default: {
                throw new IllegalStateException("Cannot merge examples: labels rank must be in range 2 to 4 inclusive. First example labels shape: " + Arrays.toString(data.get(0).getLabels().shape()));
            }
        }
        return new DataSet(featuresOut, labelsOut, featuresMaskOut, labelsMaskOut);
    }

    /**
     * Stacks rank 2 arrays row-wise into one new array.
     * <p>
     * A single input array is returned as-is (the same instance, not a copy), matching
     * the single-element shortcut of {@code merge4dCnnData} and {@code mergeTimeSeries}.
     * The column count of the result is taken from the first array.
     *
     * @param data arrays to stack; must contain at least one element
     * @return the stacked array, or {@code data[0]} itself when there is only one input
     * @throws IllegalArgumentException if {@code data} is empty
     */
    private static INDArray merge2d(INDArray[] data) {
        // Fixed: the guard used to test length == 0 and then dereference data[0].
        if (data.length == 0) {
            throw new IllegalArgumentException("Cannot merge 2d data: no arrays provided");
        }
        if (data.length == 1) {
            return data[0];
        }
        int totalRows = 0;
        for (INDArray arr : data) {
            totalRows += arr.rows();
        }
        INDArray out = Nd4j.create(totalRows, data[0].columns());
        int nextRow = 0;
        for (INDArray arr : data) {
            int rows = arr.size(0);
            if (rows == 1) {
                out.putRow(nextRow++, arr);
                continue;
            }
            out.put(new INDArrayIndex[]{NDArrayIndex.interval(nextRow, nextRow + rows), NDArrayIndex.all()}, arr);
            nextRow += rows;
        }
        return out;
    }

    /**
     * Concatenates rank 4 arrays along dimension 0 into one new array.
     * A single input array is returned unchanged. Every input must be rank 4 and agree
     * with the first array on dimensions 1 to 3.
     *
     * @param data arrays to concatenate
     * @return the concatenated array, or {@code data[0]} when only one array is given
     * @throws IllegalStateException if an array is not rank 4 or its trailing
     *                               dimensions differ from the first array's
     */
    private static INDArray merge4dCnnData(INDArray[] data) {
        if (data.length == 1) {
            return data[0];
        }
        // Output shape: the first array's shape with dimension 0 summed over all inputs.
        int[] mergedShape = Arrays.copyOf(data[0].shape(), 4);
        int idx = 1;
        while (idx < data.length) {
            mergedShape[0] += data[idx].size(0);
            idx++;
        }
        INDArray merged = Nd4j.create(mergedShape);
        INDArrayIndex[] target = new INDArrayIndex[4];
        target[1] = NDArrayIndex.all();
        target[2] = NDArrayIndex.all();
        target[3] = NDArrayIndex.all();
        int offset = 0;
        for (int k = 0; k < data.length; k++) {
            INDArray current = data[k];
            int[] shape = current.shape();
            boolean compatible = shape.length == 4;
            for (int d = 1; compatible && d < 4; d++) {
                compatible = mergedShape[d] == shape[d];
            }
            if (!compatible) {
                throw new IllegalStateException("Cannot merge CNN data: first DataSet data has shape " + Arrays.toString(data[0].shape()) + ", " + k + "th example has shape " + Arrays.toString(shape));
            }
            int examplesHere = current.size(0);
            target[0] = NDArrayIndex.interval(offset, offset + examplesHere);
            merged.put(target, current);
            offset += examplesHere;
        }
        return merged;
    }

    /**
     * Concatenates rank 3 arrays along dimension 0, sizing dimension 2 to the largest
     * value found among the inputs.
     * <p>
     * A mask of shape [totalExamples, maxLength] is produced when {@code mask} is non-null
     * or when the inputs differ in their size along dimension 2; otherwise the returned
     * mask is null.
     *
     * @param data arrays to merge; sizes are read from dimensions 0, 1 and 2
     * @param mask per-input mask arrays (individual entries may be null), or null
     * @return a two-element array: {merged data, merged mask or null}
     */
    private static INDArray[] mergeTimeSeries(INDArray[] data, INDArray[] mask) {
        int firstLength;
        // Single input: hand back the same instances, no copy.
        if (data.length == 1) {
            return new INDArray[]{data[0], mask == null ? null : mask[0]};
        }
        int maxLength = firstLength = data[0].size(2);
        boolean lengthsDiffer = false;
        int totalExamples = 0;
        for (INDArray arr : data) {
            int thisLength = arr.size(2);
            maxLength = Math.max(maxLength, thisLength);
            if (thisLength != firstLength) {
                lengthsDiffer = true;
            }
            totalExamples += arr.size(0);
        }
        boolean needMask = mask != null || lengthsDiffer;
        // Dimension 1 of the output is taken from the first array only.
        int vectorSize = data[0].size(1);
        INDArray out = Nd4j.create(new int[]{totalExamples, vectorSize, maxLength}, 'f');
        INDArray outMask = needMask ? Nd4j.create(totalExamples, maxLength) : null;
        int rowCount = 0;
        if (!needMask) {
            // All inputs have the same length and no masks: plain block copy.
            INDArrayIndex[] indexes = new INDArrayIndex[3];
            indexes[1] = NDArrayIndex.all();
            indexes[2] = NDArrayIndex.all();
            for (INDArray arr : data) {
                int nEx = arr.size(0);
                indexes[0] = NDArrayIndex.interval(rowCount, rowCount + nEx);
                out.put(indexes, arr);
                rowCount += nEx;
            }
        } else {
            INDArrayIndex[] indexes = new INDArrayIndex[3];
            indexes[1] = NDArrayIndex.all();
            for (int i = 0; i < data.length; ++i) {
                INDArray arr = data[i];
                int nEx = arr.size(0);
                int thisLength = arr.size(2);
                // Copy into the first thisLength positions of dimension 2 only.
                indexes[0] = NDArrayIndex.interval(rowCount, rowCount + nEx);
                indexes[2] = NDArrayIndex.interval(0, thisLength);
                out.put(indexes, arr);
                if (mask != null && mask[i] != null) {
                    // Input supplied its own mask: copy it into the matching region.
                    outMask.put(new INDArrayIndex[]{NDArrayIndex.interval(rowCount, rowCount + nEx), NDArrayIndex.interval(0, thisLength)}, mask[i]);
                } else {
                    // No mask supplied: mark the copied region with 1.0.
                    outMask.get(NDArrayIndex.interval(rowCount, rowCount + nEx), NDArrayIndex.interval(0, thisLength)).assign(1.0);
                }
                rowCount += nEx;
            }
        }
        return new INDArray[]{out, outMask};
    }

    /**
     * Merges the given data sets into one; delegates to {@code merge(data, false)}.
     *
     * @param data data sets to merge; must not be empty
     * @return the merged data set
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public static DataSet merge(List<DataSet> data) {
        if (!data.isEmpty()) {
            return DataSet.merge(data, false);
        }
        throw new IllegalArgumentException("Unable to merge empty dataset");
    }

    /**
     * Sums {@code numExamples()} over every data set in the collection.
     *
     * @param coll data sets to count
     * @return the total number of examples (0 for an empty collection)
     */
    private static int totalExamples(Collection<DataSet> coll) {
        int total = 0;
        Iterator<DataSet> it = coll.iterator();
        while (it.hasNext()) {
            total += it.next().numExamples();
        }
        return total;
    }

    /**
     * Returns a new data set built from the {@code interval(from, to)} selection of the
     * features, the labels and, when present, the mask arrays.
     *
     * @param from start of the interval
     * @param to   end of the interval, as interpreted by {@code NDArrayIndex.interval}
     * @return a data set over the selected range
     */
    @Override
    public org.nd4j.linalg.dataset.api.DataSet getRange(int from, int to) {
        INDArray featureMaskRange = null;
        INDArray labelMaskRange = null;
        if (this.hasMaskArrays()) {
            if (this.featuresMask != null) {
                featureMaskRange = this.featuresMask.get(NDArrayIndex.interval(from, to));
            }
            if (this.labelsMask != null) {
                labelMaskRange = this.labelsMask.get(NDArrayIndex.interval(from, to));
            }
        }
        INDArray featureRange = this.features.get(NDArrayIndex.interval(from, to));
        INDArray labelRange = this.labels.get(NDArrayIndex.interval(from, to));
        return new DataSet(featureRange, labelRange, featureMaskRange, labelMaskRange);
    }

    /**
     * Loads this data set from the binary format written by {@link #save(OutputStream)}:
     * one header byte of presence flags followed by the arrays that are present, in the
     * order features, labels, features mask, labels mask.
     * <p>
     * This method is best-effort: any failure is logged and NOT rethrown. The fields of
     * this object are only replaced once everything has been read successfully, so a
     * failed load leaves the data set unchanged. The stream is always closed.
     *
     * @param from stream to read from; it is closed by this method
     */
    @Override
    public void load(InputStream from) {
        DataInputStream dis = null;
        try {
            dis = new DataInputStream(new BufferedInputStream(from));
            byte included = dis.readByte();
            boolean hasFeatures = (included & BITMASK_FEATURES_PRESENT) != 0;
            boolean hasLabels = (included & BITMASK_LABELS_PRESENT) != 0;
            boolean hasLabelsSameAsFeatures = (included & BITMASK_LABELS_SAME_AS_FEATURES) != 0;
            boolean hasFeaturesMask = (included & BITMASK_FEATURE_MASK_PRESENT) != 0;
            boolean hasLabelsMask = (included & BITMASK_LABELS_MASK_PRESENT) != 0;

            // Read into locals first so that a failure half-way leaves this object untouched.
            INDArray loadedFeatures = hasFeatures ? Nd4j.read(dis) : null;
            INDArray loadedLabels;
            if (hasLabels) {
                loadedLabels = Nd4j.read(dis);
            } else {
                // "Same as features" means the labels were not written a second time.
                loadedLabels = hasLabelsSameAsFeatures ? loadedFeatures : null;
            }
            INDArray loadedFeaturesMask = hasFeaturesMask ? Nd4j.read(dis) : null;
            INDArray loadedLabelsMask = hasLabelsMask ? Nd4j.read(dis) : null;

            this.features = loadedFeatures;
            this.labels = loadedLabels;
            this.featuresMask = loadedFeaturesMask;
            this.labelsMask = loadedLabelsMask;
        } catch (Exception e) {
            // Deliberately not rethrown: callers of this method have never had to handle failures.
            log.error("Error loading DataSet", e);
        } finally {
            if (dis != null) {
                try {
                    dis.close();
                } catch (Exception e) {
                    log.warn("Error closing DataSet input stream", e);
                }
            }
        }
    }

    /**
     * Loads this data set from a file by delegating to {@link #load(InputStream)}.
     * The file stream is always closed, even if loading fails.
     *
     * @param from file to read
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened
     */
    @Override
    public void load(File from) {
        FileInputStream in = null;
        try {
            in = new FileInputStream(from);
            this.load(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done if closing fails; a second close is harmless.
                }
            }
        }
    }

    /**
     * Writes this data set in a binary format: one header byte of presence flags followed
     * by the arrays that are present, in the order features, labels, features mask,
     * labels mask. When the labels are the very same object as the features, only a flag
     * is written and the array is not duplicated.
     * <p>
     * This method is best-effort: any failure is logged and NOT rethrown. The stream is
     * always closed.
     *
     * @param to stream to write to; it is flushed and closed by this method
     */
    @Override
    public void save(OutputStream to) {
        byte included = 0;
        if (this.features != null) {
            included |= BITMASK_FEATURES_PRESENT;
        }
        boolean labelsAreFeatures = this.labels != null && this.labels == this.features;
        if (this.labels != null) {
            included |= labelsAreFeatures ? BITMASK_LABELS_SAME_AS_FEATURES : BITMASK_LABELS_PRESENT;
        }
        if (this.featuresMask != null) {
            included |= BITMASK_FEATURE_MASK_PRESENT;
        }
        if (this.labelsMask != null) {
            included |= BITMASK_LABELS_MASK_PRESENT;
        }

        DataOutputStream dos = null;
        try {
            dos = new DataOutputStream(new BufferedOutputStream(to));
            dos.writeByte(included);
            if (this.features != null) {
                Nd4j.write(this.features, dos);
            }
            if (this.labels != null && !labelsAreFeatures) {
                Nd4j.write(this.labels, dos);
            }
            if (this.featuresMask != null) {
                Nd4j.write(this.featuresMask, dos);
            }
            if (this.labelsMask != null) {
                Nd4j.write(this.labelsMask, dos);
            }
            dos.flush();
        } catch (Exception e) {
            // Deliberately not rethrown: callers of this method have never had to handle failures.
            log.error("Error saving DataSet", e);
        } finally {
            if (dos != null) {
                try {
                    dos.close();
                } catch (Exception e) {
                    log.warn("Error closing DataSet output stream", e);
                }
            }
        }
    }

    /**
     * Saves this data set to a file (overwriting it) by delegating to
     * {@link #save(OutputStream)}. The file stream is always closed, even if saving fails.
     *
     * @param to file to write
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened
     */
    @Override
    public void save(File to) {
        FileOutputStream out = null;
        try {
            out = new FileOutputStream(to, false);
            this.save(out);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            if (out != null) {
                try {
                    out.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done if closing fails; a second close is harmless.
                }
            }
        }
    }

    /** Not implemented here: always returns null. */
    @Override
    public DataSetIterator iterateWithMiniBatches() {
        return null;
    }

    /** @return always the empty string */
    @Override
    public String id() {
        return "";
    }

    /** @return the feature array held by this data set (the stored reference, may be null) */
    @Override
    public INDArray getFeatures() {
        return this.features;
    }

    /**
     * Replaces the feature array; the reference is stored as-is.
     *
     * @param features new feature array
     */
    @Override
    public void setFeatures(INDArray features) {
        this.features = features;
    }

    /**
     * Counts how often each label index is the arg-max (per {@code iamax}) of a label
     * tensor taken along dimension 1.
     *
     * @return map from label index to occurrence count; empty when there are no labels
     * @throws IllegalStateException if {@code iamax} returns a negative index
     */
    @Override
    public Map<Integer, Double> labelCounts() {
        HashMap<Integer, Double> counts = new HashMap<Integer, Double>();
        if (this.labels != null) {
            int tensorCount = this.labels.tensorssAlongDimension(1);
            int t = 0;
            while (t < tensorCount) {
                INDArray labelVector = this.labels.tensorAlongDimension(t, 1);
                int winner = Nd4j.getBlasWrapper().iamax(labelVector);
                if (winner < 0) {
                    throw new IllegalStateException("Please check the iamax implementation for " + Nd4j.getBlasWrapper().getClass().getName());
                }
                Double soFar = counts.get(winner);
                if (soFar != null) {
                    counts.put(winner, soFar + 1.0);
                } else {
                    counts.put(winner, 1.0);
                }
                t++;
            }
        }
        return counts;
    }

    /**
     * Applies {@code function} to the features via {@code BooleanIndexing.applyWhere}
     * using the given condition.
     *
     * @param condition condition passed to {@code applyWhere}
     * @param function  function passed to {@code applyWhere}
     */
    @Override
    public void apply(Condition condition, Function<Number, Number> function) {
        BooleanIndexing.applyWhere(this.getFeatureMatrix(), condition, function);
    }

    /**
     * Creates a copy of this data set. Features, labels and both mask arrays are
     * duplicated with {@code dup()}; arrays that are null stay null in the copy.
     * The column and label name lists are passed to the copy by reference.
     *
     * @return the copied data set
     */
    @Override
    public DataSet copy() {
        INDArray featuresCopy = this.getFeatures() == null ? null : this.getFeatures().dup();
        INDArray labelsCopy = this.getLabels() == null ? null : this.getLabels().dup();
        // Fixed: the mask arrays used to be silently dropped by copy().
        INDArray featuresMaskCopy = this.featuresMask == null ? null : this.featuresMask.dup();
        INDArray labelsMaskCopy = this.labelsMask == null ? null : this.labelsMask.dup();
        DataSet ret = new DataSet(featuresCopy, labelsCopy, featuresMaskCopy, labelsMaskCopy);
        ret.setColumnNames(this.getColumnNames());
        ret.setLabelNames(this.getLabelNames());
        return ret;
    }

    /**
     * Returns a new data set whose features are this data set's features reshaped to
     * {@code [rows, cols]}; the labels reference is reused unchanged.
     *
     * @param rows number of rows of the reshaped features
     * @param cols number of columns of the reshaped features
     * @return the new data set
     */
    @Override
    public DataSet reshape(int rows, int cols) {
        int[] targetShape = new int[]{rows, cols};
        INDArray reshapedFeatures = this.getFeatures().reshape(targetShape);
        return new DataSet(reshapedFeatures, this.getLabels());
    }

    /**
     * Multiplies the features in place ({@code muli}) by the given scalar.
     *
     * @param num the multiplier
     */
    @Override
    public void multiplyBy(double num) {
        this.getFeatures().muli(Nd4j.scalar(num));
    }

    /**
     * Divides the features in place ({@code divi}) by the given scalar.
     *
     * @param num the divisor
     */
    @Override
    public void divideBy(int num) {
        this.getFeatures().divi(Nd4j.scalar(num));
    }

    /** Shuffles this data set using the current system time in milliseconds as the seed. */
    @Override
    public void shuffle() {
        this.shuffle(System.currentTimeMillis());
    }

    /**
     * Shuffles features, labels and (when present) both mask arrays with
     * {@code Nd4j.shuffle}. Each array gets its own {@link Random} created from the same
     * seed, and the dimension list passed for each array is
     * {@code ArrayUtil.range(1, rank)} of THAT array.
     * <p>
     * NOTE(review): the intent appears to be that identical seeds yield the same
     * permutation for all four arrays - confirm against {@code Nd4j.shuffle}.
     *
     * @param seed seed used for every shuffle
     */
    public void shuffle(long seed) {
        INDArray featureArr = this.getFeatureMatrix();
        INDArray labelArr = this.getLabels();
        Nd4j.shuffle(featureArr, new Random(seed), ArrayUtil.range(1, featureArr.rank()));
        Nd4j.shuffle(labelArr, new Random(seed), ArrayUtil.range(1, labelArr.rank()));
        // Fixed: masks were shuffled with the dimension list of the features/labels, which
        // names dimensions a mask does not have when its rank differs (e.g. rank 3
        // features with a rank 2 mask, as built by mergeTimeSeries).
        INDArray featureMaskArr = this.getFeaturesMaskArray();
        if (featureMaskArr != null) {
            Nd4j.shuffle(featureMaskArr, new Random(seed), ArrayUtil.range(1, featureMaskArr.rank()));
        }
        INDArray labelMaskArr = this.getLabelsMaskArray();
        if (labelMaskArr != null) {
            Nd4j.shuffle(labelMaskArr, new Random(seed), ArrayUtil.range(1, labelMaskArr.rank()));
        }
    }

    /**
     * Clamps every feature element into {@code [min, max]} in place: values below
     * {@code min} are replaced by {@code min}, values above {@code max} by {@code max}.
     *
     * @param min lower bound
     * @param max upper bound
     */
    @Override
    public void squishToRange(double min, double max) {
        int pos = 0;
        while (pos < this.getFeatures().length()) {
            double value = (Double)this.getFeatures().getScalar(pos).element();
            if (value < min) {
                this.getFeatures().put(pos, Nd4j.scalar(min));
            } else if (value > max) {
                this.getFeatures().put(pos, Nd4j.scalar(max));
            }
            pos++;
        }
    }

    /**
     * Rescales the features via {@code FeatureUtil.scaleMinMax}.
     *
     * @param min minimum passed to {@code scaleMinMax}
     * @param max maximum passed to {@code scaleMinMax}
     */
    @Override
    public void scaleMinAndMax(double min, double max) {
        FeatureUtil.scaleMinMax(min, max, this.getFeatureMatrix());
    }

    /** Scales the features via {@code FeatureUtil.scaleByMax}. */
    @Override
    public void scale() {
        FeatureUtil.scaleByMax(this.getFeatures());
    }

    /**
     * Replaces the features with the horizontal stack ({@code Nd4j.hstack}) of the
     * current features and {@code toAdd}.
     *
     * @param toAdd array to append
     */
    @Override
    public void addFeatureVector(INDArray toAdd) {
        this.setFeatures(Nd4j.hstack(this.getFeatureMatrix(), toAdd));
    }

    /**
     * Overwrites one row of the features ({@code putRow}) with the given vector.
     *
     * @param feature the new row contents
     * @param example index of the row to overwrite
     */
    @Override
    public void addFeatureVector(INDArray feature, int example) {
        this.getFeatures().putRow(example, feature);
    }

    /**
     * Fits a fresh {@code NormalizerStandardize} on this data set and then applies its
     * transform to this same data set.
     */
    @Override
    public void normalize() {
        NormalizerStandardize inClassPreProcessor = new NormalizerStandardize();
        inClassPreProcessor.fit(this);
        inClassPreProcessor.transform(this);
    }

    /** Binarizes the features with a cutoff of 0.0; see {@link #binarize(double)}. */
    @Override
    public void binarize() {
        this.binarize(0.0);
    }

    /**
     * Binarizes the features in place: every element strictly greater than
     * {@code cutoff} becomes 1, every other element becomes 0. Values are read through
     * the features' linear view and written back with {@code putScalar}.
     *
     * @param cutoff threshold above which an element becomes 1
     */
    @Override
    public void binarize(double cutoff) {
        INDArray flat = this.getFeatureMatrix().linearView();
        int pos = 0;
        while (pos < this.getFeatures().length()) {
            int bit = flat.getDouble(pos) > cutoff ? 1 : 0;
            this.getFeatures().putScalar(pos, bit);
            pos++;
        }
    }

    /**
     * Standardizes the features: subtracts the mean taken along dimension 0, then
     * divides by the standard deviation taken along dimension 0.
     *
     * @deprecated marked deprecated in the source; {@link #normalize()} is the other
     *             normalization entry point visible in this class
     */
    @Override
    @Deprecated
    public void normalizeZeroMeanZeroUnitVariance() {
        // Both statistics are computed before the features are modified.
        INDArray columnMeans = this.getFeatures().mean(0);
        INDArray columnStds = this.getFeatureMatrix().std(0);
        this.setFeatures(this.getFeatures().subiRowVector(columnMeans));
        // EPS_THRESHOLD is added to the deviations before dividing
        // (presumably to avoid division by zero - confirm).
        columnStds.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        this.setFeatures(this.getFeatures().diviRowVector(columnStds));
    }

    /** @return the number of columns of the features */
    @Override
    public int numInputs() {
        return this.getFeatures().columns();
    }

    /**
     * Checks that features and labels have the same size along dimension 0.
     *
     * @throws IllegalStateException if the sizes differ
     */
    @Override
    public void validate() {
        if (this.getFeatures().size(0) != this.getLabels().size(0)) {
            throw new IllegalStateException("Invalid dataset");
        }
    }

    /**
     * @return the index reported by {@code iamax} over the whole labels array
     *         (callers in this class invoke it on single-example data sets)
     */
    @Override
    public int outcome() {
        return Nd4j.getBlasWrapper().iamax(this.getLabels());
    }

    /**
     * Replaces the labels with a newly created {@code [numExamples, labels]} array.
     *
     * @param labels number of label columns of the new array
     */
    @Override
    public void setNewNumberOfLabels(int labels) {
        this.setLabels(Nd4j.create(this.numExamples(), labels));
    }

    /**
     * Sets the label row of one example to the outcome vector built by
     * {@code FeatureUtil.toOutcomeVector(label, numOutcomes())}.
     *
     * @param example row index, valid range {@code 0 .. numExamples() - 1}
     * @param label   label index, valid range {@code 0 .. numOutcomes() - 1}
     * @throws IllegalArgumentException if either index is out of range
     */
    @Override
    public void setOutcome(int example, int label) {
        // Fixed: the bounds used '>' and therefore accepted example == numExamples() and
        // label == numOutcomes(); a negative example index was not rejected either.
        if (example < 0 || example >= this.numExamples()) {
            throw new IllegalArgumentException("No example at " + example);
        }
        if (label < 0 || label >= this.numOutcomes()) {
            throw new IllegalArgumentException("Illegal label");
        }
        INDArray outcome = FeatureUtil.toOutcomeVector(label, this.numOutcomes());
        this.getLabels().putRow(example, outcome);
    }

    /**
     * Returns the example at index {@code i} as its own data set.
     * <p>
     * If this data set holds exactly one example, {@code this} is returned. For rank 4
     * features the slice is reshaped so that it keeps a leading dimension of size 1.
     * Mask arrays are not carried over to the returned data set.
     *
     * @param i example index, valid range {@code 0 .. numExamples() - 1}
     * @return the data set for that example
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    @Override
    public DataSet get(int i) {
        // Fixed: the upper bound used '>' and let i == numExamples() through.
        if (i < 0 || i >= this.numExamples()) {
            throw new IllegalArgumentException("invalid example number");
        }
        if (i == 0 && this.numExamples() == 1) {
            return this;
        }
        if (this.getFeatureMatrix().rank() == 4) {
            INDArray slice = this.getFeatureMatrix().slice(i);
            // Prepend a dimension of size 1 to the slice's shape.
            return new DataSet(slice.reshape(ArrayUtil.combine((int[][])new int[][]{{1}, slice.shape()})), this.getLabels().slice(i));
        }
        return new DataSet(this.getFeatures().slice(i), this.getLabels().slice(i));
    }

    /**
     * Returns a data set made of the given rows ({@code getRows}) of features and labels.
     *
     * @param i row indices to select
     * @return the new data set (without mask arrays)
     */
    @Override
    public DataSet get(int[] i) {
        return new DataSet(this.getFeatures().getRows(i), this.getLabels().getRows(i));
    }

    /**
     * Splits this data set into consecutive groups of at most {@code num} examples and
     * merges each group back into one data set.
     *
     * @param num maximum number of examples per batch
     * @return the batches in order
     */
    @Override
    public List<DataSet> batchBy(int num) {
        List<DataSet> batches = Lists.newArrayList();
        List<List<DataSet>> groups = Lists.partition(this.asList(), num);
        for (List<DataSet> group : groups) {
            batches.add(DataSet.merge(group));
        }
        return batches;
    }

    /**
     * Returns a merged data set containing only those examples whose {@code outcome()}
     * is one of the given labels.
     *
     * @param labels label indices to keep
     * @return the merged data set of the kept examples
     * @throws IllegalArgumentException (from {@code merge}) if no example matches
     */
    @Override
    public DataSet filterBy(int[] labels) {
        HashSet<Integer> wanted = new HashSet<Integer>();
        for (int label : labels) {
            wanted.add(label);
        }
        ArrayList<DataSet> kept = new ArrayList<DataSet>();
        for (DataSet example : this.asList()) {
            if (wanted.contains(example.outcome())) {
                kept.add(example);
            }
        }
        return DataSet.merge(kept);
    }

    /**
     * Keeps only the examples whose outcome is in {@code labels} and re-encodes their
     * labels so that {@code labels[k]} becomes label index {@code k} in a new label
     * matrix with {@code labels.length} columns. Features and labels of this data set
     * are replaced by the filtered versions.
     *
     * @param labels label indices to keep, in the order of their new indices
     * @throws IllegalStateException if the number of re-encoded labels does not match
     *                               the new label matrix, or a label has no mapping
     */
    @Override
    public void filterAndStrip(int[] labels) {
        DataSet kept = this.filterBy(labels);

        // Old label index -> position in the 'labels' argument (the new index).
        HashMap<Integer, Integer> newIndexOf = new HashMap<Integer, Integer>();
        int position = 0;
        for (int oldLabel : labels) {
            newIndexOf.put(oldLabel, position);
            position++;
        }

        ArrayList<Integer> remapped = new ArrayList<Integer>();
        for (int ex = 0; ex < kept.numExamples(); ex++) {
            int oldOutcome = kept.get(ex).outcome();
            remapped.add(newIndexOf.get(oldOutcome));
        }

        INDArray relabeled = Nd4j.create(kept.numExamples(), labels.length);
        if (relabeled.rows() != remapped.size()) {
            throw new IllegalStateException("Inconsistent label sizes");
        }
        for (int row = 0; row < relabeled.rows(); row++) {
            Integer newLabel = remapped.get(row);
            if (newLabel == null) {
                throw new IllegalStateException("Label not found on row " + row);
            }
            relabeled.putRow(row, FeatureUtil.toOutcomeVector(newLabel, labels.length));
        }
        this.setFeatures(kept.getFeatures());
        this.setLabels(relabeled);
    }

    /**
     * Partitions the examples into consecutive groups of at most {@code num} and merges
     * each group into a single data set.
     *
     * @param num maximum number of examples per batch
     * @return the batches in order
     */
    @Override
    public List<DataSet> dataSetBatches(int num) {
        ArrayList<DataSet> batches = new ArrayList<DataSet>();
        for (List<DataSet> group : Lists.partition(this.asList(), num)) {
            batches.add(DataSet.merge(group));
        }
        return batches;
    }

    /**
     * Calls {@link #sortByLabel()} and then {@link #batchByNumLabels()}.
     *
     * @return the batches produced by {@code batchByNumLabels()}
     */
    @Override
    public List<DataSet> sortAndBatchByNumLabels() {
        this.sortByLabel();
        return this.batchByNumLabels();
    }

    /** @return this data set split by {@link #batchBy(int)} with {@code numOutcomes()} as batch size */
    @Override
    public List<DataSet> batchByNumLabels() {
        return this.batchBy(this.numOutcomes());
    }

    /**
     * Splits this data set into one data set per example.
     * <p>
     * Each element selects index {@code i} of dimension 0 (as an inclusive interval
     * {@code [i, i]}) and everything along the remaining dimensions, for features and
     * labels of rank 2 to 4. Mask arrays, when present, are selected with two indices
     * (example interval, all).
     *
     * @return one data set per example, in order
     * @throws IllegalStateException if there is at least one example and the features
     *                               or labels rank is outside 2..4
     */
    @Override
    public List<DataSet> asList() {
        int numExamples = this.numExamples();
        List<DataSet> list = new ArrayList<DataSet>(numExamples);
        int featuresRank = this.getFeatures().rank();
        int labelsRank = this.getLabels().rank();
        if (numExamples > 0) {
            // Fixed: both messages used to talk about the wrong array; the labels branch
            // even reported the features shape.
            if (featuresRank < 2 || featuresRank > 4) {
                throw new IllegalStateException("Cannot convert to list: features rank must be in range 2 to 4 inclusive. Features shape: " + Arrays.toString(this.getFeatures().shape()));
            }
            if (labelsRank < 2 || labelsRank > 4) {
                throw new IllegalStateException("Cannot convert to list: labels rank must be in range 2 to 4 inclusive. Labels shape: " + Arrays.toString(this.getLabels().shape()));
            }
        }
        for (int i = 0; i < numExamples; ++i) {
            INDArray featuresHere = this.getFeatures().get(DataSet.singleExampleIndexes(i, featuresRank));
            INDArray featureMaskHere = this.featuresMask != null ? this.featuresMask.get(DataSet.singleExampleIndexes(i, 2)) : null;
            INDArray labelsHere = this.getLabels().get(DataSet.singleExampleIndexes(i, labelsRank));
            INDArray labelMaskHere = this.labelsMask != null ? this.labelsMask.get(DataSet.singleExampleIndexes(i, 2)) : null;
            list.add(new DataSet(featuresHere, labelsHere, featureMaskHere, labelMaskHere));
        }
        return list;
    }

    /** Builds fresh indexes selecting example {@code example} and all of the other {@code rank - 1} dimensions. */
    private static INDArrayIndex[] singleExampleIndexes(int example, int rank) {
        INDArrayIndex[] indexes = new INDArrayIndex[rank];
        indexes[0] = NDArrayIndex.interval(example, example, true);
        for (int d = 1; d < rank; ++d) {
            indexes[d] = NDArrayIndex.all();
        }
        return indexes;
    }

    /**
     * Shuffles this data set with a seed drawn from {@code rng}, then splits it with
     * {@link #splitTestAndTrain(int)}.
     *
     * @param numHoldout number of leading examples that go into the first part
     * @param rng        source of the shuffle seed
     * @return the two parts
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout, Random rng) {
        this.shuffle(rng.nextLong());
        return this.splitTestAndTrain(numHoldout);
    }

    /**
     * Splits this data set into the first {@code numHoldout} examples and the rest.
     * Both parts are built from {@code get(interval, all)} selections of features and
     * labels; mask arrays are not carried over.
     *
     * @param numHoldout number of leading examples that go into the first part
     * @return a {@code SplitTestAndTrain} of (first part, remaining part)
     * @throws IllegalStateException    if the data set has one example or fewer
     * @throws IllegalArgumentException if {@code numHoldout >= numExamples()}
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout) {
        int total = this.numExamples();
        if (total <= 1) {
            throw new IllegalStateException("Cannot split DataSet with <= 1 rows (data set has " + total + " example)");
        }
        if (numHoldout >= total) {
            throw new IllegalArgumentException("Unable to split on size equal or larger than the number of rows (# numExamples=" + total + ", numHoldout=" + numHoldout + ")");
        }
        INDArray headFeatures = this.getFeatureMatrix().get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all());
        INDArray headLabels = this.getLabels().get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all());
        DataSet head = new DataSet(headFeatures, headLabels);

        INDArray tailFeatures = this.getFeatureMatrix().get(NDArrayIndex.interval(numHoldout, this.numExamples()), NDArrayIndex.all());
        INDArray tailLabels = this.getLabels().get(NDArrayIndex.interval(numHoldout, total), NDArrayIndex.all());
        DataSet tail = new DataSet(tailFeatures, tailLabels);

        return new SplitTestAndTrain(head, tail);
    }

    /** @return the label array held by this data set (the stored reference, may be null) */
    @Override
    public INDArray getLabels() {
        return this.labels;
    }

    /**
     * @param idx index into the label name list
     * @return the label name stored at that index
     */
    @Override
    public String getLabelName(int idx) {
        return this.labelNames.get(idx);
    }

    /**
     * Looks up the label name for every label index stored in {@code idxs}.
     *
     * @param idxs array whose elements (read in linear order and truncated to int) are
     *             label indices
     * @return the label names, in the same order as the elements of {@code idxs}
     */
    @Override
    public List<String> getLabelNames(INDArray idxs) {
        int n = idxs.length();
        ArrayList<String> ret = new ArrayList<String>(n);
        for (int i = 0; i < n; ++i) {
            // Fixed: the loop counter itself used to be used as the label index, so the
            // contents of idxs were ignored.
            ret.add(this.getLabelName((int)idxs.getDouble(i)));
        }
        return ret;
    }

    /**
     * Replaces the label array; the reference is stored as-is.
     *
     * @param labels new label array
     */
    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** @return the same array as {@link #getFeatures()} */
    @Override
    public INDArray getFeatureMatrix() {
        return this.getFeatures();
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void sortByLabel() {
        void var6_11;
        Queue q;
        HashMap<Integer, ArrayDeque<DataSet>> map = new HashMap<Integer, ArrayDeque<DataSet>>();
        List<DataSet> data = this.asList();
        int numLabels = this.numOutcomes();
        int examples = this.numExamples();
        for (DataSet dataSet : data) {
            int label = dataSet.outcome();
            q = (ArrayDeque<DataSet>)map.get(label);
            if (q == null) {
                q = new ArrayDeque<DataSet>();
                map.put(label, (ArrayDeque<DataSet>)q);
            }
            q.add(dataSet);
        }
        for (Map.Entry entry : map.entrySet()) {
            log.info("Label " + entry + " has " + ((Queue)entry.getValue()).size() + " elements");
        }
        boolean optimal = true;
        boolean bl = false;
        while (var6_11 < examples) {
            if (optimal) {
                for (int j = 0; j < numLabels; ++j) {
                    q = (Queue)map.get(j);
                    if (q == null) {
                        optimal = false;
                    } else {
                        DataSet next = (DataSet)q.poll();
                        if (next != null) {
                            this.addRow(next, (int)var6_11);
                            ++var6_11;
                            continue;
                        }
                        optimal = false;
                    }
                    break;
                }
            } else {
                DataSet add = null;
                for (Queue q2 : map.values()) {
                    if (q2.isEmpty()) continue;
                    add = (DataSet)q2.poll();
                    break;
                }
                this.addRow(add, (int)var6_11);
            }
            ++var6_11;
        }
    }

    /**
     * Overwrites row {@code i} of this data set's features and labels ({@code putRow})
     * with the features and labels of {@code d}.
     *
     * @param d source of the new row contents; must not be null
     * @param i row index, valid range {@code 0 .. numExamples() - 1}
     * @throws IllegalArgumentException if {@code d} is null or {@code i} is out of range
     */
    @Override
    public void addRow(DataSet d, int i) {
        // Fixed: the bound used '>' and accepted i == numExamples(); negative i was not rejected.
        if (d == null || i < 0 || i >= this.numExamples()) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        this.getFeatures().putRow(i, d.getFeatures());
        this.getLabels().putRow(i, d.getLabels());
    }

    /**
     * @param data data set to inspect
     * @return the largest value in the labels ({@code maxNumber()}), converted to float
     *         and then truncated to int
     */
    private int getLabel(DataSet data) {
        float largest = data.getLabels().maxNumber().floatValue();
        return (int)largest;
    }

    /** @return the sum of the features along dimension 1 */
    @Override
    public INDArray exampleSums() {
        return this.getFeatures().sum(1);
    }

    /** @return the maximum of the features along dimension 1 */
    @Override
    public INDArray exampleMaxs() {
        return this.getFeatures().max(1);
    }

    /** @return the mean of the features along dimension 1 */
    @Override
    public INDArray exampleMeans() {
        return this.getFeatures().mean(1);
    }

    /**
     * Samples without replacement using {@code Nd4j.getRandom()}.
     *
     * @param numSamples number of examples to draw
     * @return the sampled data set
     */
    @Override
    public DataSet sample(int numSamples) {
        return this.sample(numSamples, Nd4j.getRandom());
    }

    /**
     * Samples {@code numSamples} examples with the supplied generator,
     * requesting sampling without replacement.
     *
     * @param numSamples number of examples to draw
     * @param rng        source of random row indices
     * @return a new data set holding the sampled examples
     */
    @Override
    public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng) {
        final boolean withReplacement = false;
        return this.sample(numSamples, rng, withReplacement);
    }

    /**
     * Samples {@code numSamples} examples using the random number generator
     * obtained from {@code Nd4j.getRandom()}.
     *
     * @param numSamples      number of examples to draw
     * @param withReplacement whether the same example may be drawn more than once
     * @return a new data set holding the sampled examples
     */
    @Override
    public DataSet sample(int numSamples, boolean withReplacement) {
        org.nd4j.linalg.api.rng.Random defaultRng = Nd4j.getRandom();
        return this.sample(numSamples, defaultRng, withReplacement);
    }

    /**
     * Builds a new data set of {@code numSamples} examples drawn at random
     * from this one.
     *
     * <p>When {@code withReplacement} is {@code false}, each source row is
     * used at most once: already-picked indices are remembered and re-drawn.
     *
     * @param numSamples      number of examples to draw
     * @param rng             source of random row indices
     * @param withReplacement whether the same example may be drawn more than once
     * @return a new data set holding the sampled features and labels
     * @throws IllegalArgumentException if sampling without replacement is
     *         requested for more examples than this data set contains
     */
    @Override
    public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng, boolean withReplacement) {
        int available = this.numExamples();
        // Drawing more distinct rows than exist can never terminate.
        if (!withReplacement && numSamples > available) {
            throw new IllegalArgumentException("Cannot sample " + numSamples + " examples without replacement from a data set of " + available + " examples");
        }
        INDArray examples = Nd4j.create(numSamples, this.getFeatures().columns());
        INDArray outcomes = Nd4j.create(numSamples, this.numOutcomes());
        HashSet<Integer> added = new HashSet<Integer>();
        for (int i = 0; i < numSamples; ++i) {
            int picked = rng.nextInt(available);
            if (!withReplacement) {
                // add() returns false for an index that was already chosen;
                // keep drawing until a fresh one is recorded.
                while (!added.add(picked)) {
                    picked = rng.nextInt(available);
                }
            }
            DataSet example = this.get(picked);
            examples.putRow(i, example.getFeatures());
            outcomes.putRow(i, example.getLabels());
        }
        return new DataSet(examples, outcomes);
    }

    /**
     * Rounds every element of the feature array in place using
     * {@code MathUtils.roundDouble(value, roundTo)}.
     *
     * @param roundTo rounding argument forwarded to {@code MathUtils.roundDouble}
     */
    @Override
    public void roundToTheNearest(int roundTo) {
        INDArray featureArray = this.getFeatures();
        for (int i = 0; i < featureArray.length(); ++i) {
            // element() is typed as Object; going through Number instead of a
            // hard Double cast avoids a ClassCastException should the boxed
            // value be a Float (or any other numeric wrapper).
            double curr = ((Number)featureArray.getScalar(i).element()).doubleValue();
            featureArray.put(i, Nd4j.scalar(MathUtils.roundDouble(curr, roundTo)));
        }
    }

    /**
     * Reports the number of outcomes, taken as the column count of the
     * labels array.
     *
     * @return {@code getLabels().columns()}
     */
    @Override
    public int numOutcomes() {
        INDArray labelArray = this.getLabels();
        return labelArray.columns();
    }

    /**
     * Reports the number of examples, taken as the size of the labels array
     * along dimension 0.
     *
     * @return {@code getLabels().size(0)}
     */
    @Override
    public int numExamples() {
        INDArray labelArray = this.getLabels();
        return labelArray.size(0);
    }

    /**
     * Renders the features and labels (and any mask arrays that are set) as
     * banner-separated sections, replacing every {@code ';'} in each array's
     * string form with a newline.
     *
     * @return the rendered text, or an empty string (after logging a notice)
     *         when either the features or the labels field is {@code null}
     */
    public String toString() {
        if (this.features == null || this.labels == null) {
            log.info("Features or labels are null values");
            return "";
        }
        StringBuilder text = new StringBuilder();
        text.append("===========INPUT===================\n");
        text.append(this.getFeatures().toString().replace(";", "\n"));
        text.append("\n=================OUTPUT==================\n");
        text.append(this.getLabels().toString().replace(";", "\n"));
        if (this.featuresMask != null) {
            text.append("\n===========INPUT MASK===================\n");
            text.append(this.getFeaturesMaskArray().toString().replace(";", "\n"));
        }
        if (this.labelsMask != null) {
            text.append("\n===========OUTPUT MASK===================\n");
            text.append(this.getLabelsMaskArray().toString().replace(";", "\n"));
        }
        return text.toString();
    }

    /**
     * Returns the label-name list held by this data set (the stored
     * reference itself, not a copy).
     *
     * @return the current label names
     * @deprecated {@link #getLabelNamesList()} returns the same list.
     */
    @Override
    @Deprecated
    public List<String> getLabelNames() {
        return labelNames;
    }

    /**
     * Returns the label-name list held by this data set (the stored
     * reference itself, not a copy).
     *
     * @return the current label names
     */
    @Override
    public List<String> getLabelNamesList() {
        return labelNames;
    }

    /**
     * Replaces the label-name list held by this data set. The given
     * reference is stored as-is; no copy is made.
     *
     * @param labelNames the new label names
     */
    @Override
    public void setLabelNames(List<String> labelNames) {
        this.labelNames = labelNames;
    }

    /**
     * Returns the column-name list held by this data set (the stored
     * reference itself, not a copy). Marked deprecated.
     *
     * @return the current column names
     */
    @Override
    @Deprecated
    public List<String> getColumnNames() {
        return columnNames;
    }

    /**
     * Replaces the column-name list held by this data set. The given
     * reference is stored as-is; no copy is made. Marked deprecated.
     *
     * @param columnNames the new column names
     */
    @Override
    @Deprecated
    public void setColumnNames(List<String> columnNames) {
        this.columnNames = columnNames;
    }

    /**
     * Splits this data set by a fraction: the training size is
     * {@code percentTrain * numExamples()} truncated to an integer and
     * clamped to at least 1, then handed to the count-based overload.
     *
     * @param percentTrain fraction of the examples to use for training
     * @return the split produced by {@code splitTestAndTrain(int)}
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(double percentTrain) {
        int scaled = (int)(percentTrain * (double)this.numExamples());
        // Never request fewer than one training example.
        int trainCount = Math.max(1, scaled);
        return this.splitTestAndTrain(trainCount);
    }

    /**
     * Iterates over the data sets produced by {@code asList()}.
     *
     * @return an iterator over the list returned by {@code asList()}
     */
    @Override
    public Iterator<DataSet> iterator() {
        Iterator<DataSet> cursor = this.asList().iterator();
        return cursor;
    }

    /**
     * Returns the features mask array currently stored on this data set.
     *
     * @return the features mask, or {@code null} when none is set
     */
    @Override
    public INDArray getFeaturesMaskArray() {
        return featuresMask;
    }

    /**
     * Stores the given array as this data set's features mask. The reference
     * is kept as-is; no copy is made.
     *
     * @param featuresMask the new features mask
     */
    @Override
    public void setFeaturesMaskArray(INDArray featuresMask) {
        this.featuresMask = featuresMask;
    }

    /**
     * Returns the labels mask array currently stored on this data set.
     *
     * @return the labels mask, or {@code null} when none is set
     */
    @Override
    public INDArray getLabelsMaskArray() {
        return labelsMask;
    }

    /**
     * Stores the given array as this data set's labels mask. The reference
     * is kept as-is; no copy is made.
     *
     * @param labelsMask the new labels mask
     */
    @Override
    public void setLabelsMaskArray(INDArray labelsMask) {
        this.labelsMask = labelsMask;
    }

    /**
     * Tells whether at least one mask array (labels or features) is set.
     *
     * @return {@code true} unless both mask fields are {@code null}
     */
    @Override
    public boolean hasMaskArrays() {
        boolean bothMissing = this.labelsMask == null && this.featuresMask == null;
        return !bothMissing;
    }

    /**
     * Compares this data set with another object. Two data sets are equal
     * when their features, labels, features mask and labels mask each match
     * (both {@code null}, or equal per {@code INDArray.equals}).
     *
     * @param o the object to compare against
     * @return {@code true} if {@code o} is a {@code DataSet} with matching arrays
     */
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof DataSet)) {
            return false;
        }
        DataSet other = (DataSet)o;
        // Short-circuits in the same order: features, labels, then the masks.
        return DataSet.equalOrBothNull(this.features, other.features)
                && DataSet.equalOrBothNull(this.labels, other.labels)
                && DataSet.equalOrBothNull(this.featuresMask, other.featuresMask)
                && DataSet.equalOrBothNull(this.labelsMask, other.labelsMask);
    }

    /**
     * Null-safe array comparison.
     *
     * @param first  one array, possibly {@code null}
     * @param second the other array, possibly {@code null}
     * @return {@code true} when both are {@code null}, or both are non-null
     *         and {@code first.equals(second)}; {@code false} otherwise
     */
    private static boolean equalOrBothNull(INDArray first, INDArray second) {
        if (first == null) {
            return second == null;
        }
        return second != null && first.equals(second);
    }

    /**
     * Combines the hash codes of the features, labels, features mask and
     * labels mask (in that order) with the usual 31-multiplier scheme; a
     * {@code null} array contributes 0.
     *
     * @return the combined hash code
     */
    public int hashCode() {
        INDArray[] parts = new INDArray[]{this.getFeatures(), this.getLabels(), this.getFeaturesMaskArray(), this.getLabelsMaskArray()};
        int hash = 0;
        for (INDArray part : parts) {
            hash = 31 * hash + (part == null ? 0 : part.hashCode());
        }
        return hash;
    }
}

