/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.TreeMap;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.Condition;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * In-memory data set pairing a feature matrix with a label matrix, one row per example.
 *
 * <p>Most mutators operate in place on the underlying {@link INDArray}s. Single-example data
 * sets produced by {@link #get(int)} and {@link #asList()} are built from {@code getRow}, so
 * they may share storage with this data set (NOTE(review): depends on ND4J's {@code getRow}
 * view semantics — confirm).
 */
public class DataSet
implements org.nd4j.linalg.dataset.api.DataSet {
    private static final long serialVersionUID = 1935520764586513365L;
    private static final Logger log = LoggerFactory.getLogger(DataSet.class);
    private List<String> columnNames = new ArrayList<String>();
    private List<String> labelNames = new ArrayList<String>();
    private INDArray features;
    private INDArray labels;

    /** Creates a placeholder data set whose features and labels are zero arrays of shape {1}. */
    public DataSet() {
        this(Nd4j.zeros(new int[]{1}), Nd4j.zeros(new int[]{1}));
    }

    /**
     * Creates a data set from a feature matrix and a label matrix.
     *
     * @param first  the features, one row per example
     * @param second the labels, one row per example
     * @throws IllegalStateException if the two arrays have a different number of rows
     */
    public DataSet(INDArray first, INDArray second) {
        if (first.rows() != second.rows()) {
            throw new IllegalStateException("Invalid data set; first and second do not have equal rows. First was " + first.rows() + " second was " + second.rows());
        }
        this.features = first;
        this.labels = second;
    }

    @Override
    public INDArray getFeatures() {
        return this.features;
    }

    /** Applies {@code function} in place to every feature element that satisfies {@code condition}. */
    @Override
    public void apply(Condition condition, Function<Number, Number> function) {
        BooleanIndexing.applyWhere(this.getFeatureMatrix(), condition, function);
    }

    @Override
    public void setFeatures(INDArray features) {
        this.features = features;
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** Returns a new data set holding duplicates ({@code dup()}) of the features and labels. */
    @Override
    public DataSet copy() {
        return new DataSet(this.getFeatures().dup(), this.getLabels().dup());
    }

    /** Returns a placeholder data set whose features and labels are zero arrays of shape {1}. */
    public static DataSet empty() {
        return new DataSet(Nd4j.zeros(new int[]{1}), Nd4j.zeros(new int[]{1}));
    }

    /**
     * Stacks the examples of all given data sets, in order, into one new data set.
     *
     * <p>The feature and label column counts are taken from the first data set; all others are
     * expected to match.
     *
     * @param data the data sets to merge
     * @return a new data set backed by freshly created matrices
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public static DataSet merge(List<DataSet> data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }
        DataSet first = data.get(0);
        int numExamples = DataSet.totalExamples(data);
        INDArray in = Nd4j.create(numExamples, first.getFeatures().columns());
        INDArray out = Nd4j.create(numExamples, first.getLabels().columns());
        int count = 0;
        for (DataSet d : data) {
            for (int j = 0; j < d.numExamples(); ++j) {
                DataSet example = d.get(j);
                in.putRow(count, example.getFeatures());
                out.putRow(count, example.getLabels());
                ++count;
            }
        }
        return new DataSet(in, out);
    }

    /**
     * Returns a new data set whose features are reshaped to {@code rows x cols}; labels are
     * reused as-is, so {@code rows} must equal the current number of label rows.
     */
    @Override
    public DataSet reshape(int rows, int cols) {
        return new DataSet(this.getFeatures().reshape(new int[]{rows, cols}), this.getLabels());
    }

    /** Multiplies every feature in place by {@code num}. */
    @Override
    public void multiplyBy(double num) {
        this.getFeatures().muli(Nd4j.scalar(num));
    }

    /** Divides every feature in place by {@code num}. */
    @Override
    public void divideBy(int num) {
        this.getFeatures().divi(Nd4j.scalar(num));
    }

    /** Randomly permutes the examples, replacing the features and labels with new matrices. */
    @Override
    public void shuffle() {
        List<DataSet> list = this.asList();
        Collections.shuffle(list);
        DataSet ret = DataSet.merge(list);
        this.setFeatures(ret.getFeatures());
        this.setLabels(ret.getLabels());
    }

    /** Clamps every feature in place to the closed range [{@code min}, {@code max}]. */
    @Override
    public void squishToRange(double min, double max) {
        for (int i = 0; i < this.getFeatures().length(); ++i) {
            double curr = this.featureValueAt(i);
            if (curr < min) {
                this.getFeatures().put(i, Nd4j.scalar(min));
            } else if (curr > max) {
                this.getFeatures().put(i, Nd4j.scalar(max));
            }
        }
    }

    /** Scales the features in place via {@link FeatureUtil#scaleByMax}. */
    @Override
    public void scale() {
        FeatureUtil.scaleByMax(this.getFeatures());
    }

    /**
     * Appends {@code toAdd} as extra feature columns by horizontally stacking it onto the
     * current feature matrix. {@code toAdd} is expected to have one row per example.
     */
    @Override
    public void addFeatureVector(INDArray toAdd) {
        // The previous body stacked an empty array and silently dropped toAdd.
        this.setFeatures(Nd4j.hstack(this.getFeatureMatrix(), toAdd));
    }

    /**
     * Writes {@code feature} into the feature row of example {@code example}.
     *
     * <p>NOTE(review): the previous body stacked an empty array and ignored {@code feature};
     * replacing the row is the reading that fits a fixed-width feature matrix — confirm.
     */
    @Override
    public void addFeatureVector(INDArray feature, int example) {
        this.getFeatures().putRow(example, feature);
    }

    /** Normalizes the features in place via {@link FeatureUtil#normalizeMatrix}. */
    @Override
    public void normalize() {
        FeatureUtil.normalizeMatrix(this.getFeatures());
    }

    /** Binarizes the features with a cutoff of 0.0. */
    @Override
    public void binarize() {
        this.binarize(0.0);
    }

    /** Sets every feature in place to 1 if it is strictly greater than {@code cutoff}, else 0. */
    @Override
    public void binarize(double cutoff) {
        for (int i = 0; i < this.getFeatures().length(); ++i) {
            float binary = this.featureValueAt(i) > cutoff ? 1.0f : 0.0f;
            this.getFeatures().put(i, Nd4j.scalar(binary));
        }
    }

    /**
     * Standardizes each feature column in place: subtracts the column mean and divides by the
     * column standard deviation plus 1e-6 (the epsilon avoids division by zero).
     */
    @Override
    public void normalizeZeroMeanZeroUnitVariance() {
        INDArray columnMeans = this.getFeatures().mean(0);
        INDArray columnStds = this.getFeatureMatrix().std(0);
        this.setFeatures(this.getFeatures().subiRowVector(columnMeans));
        columnStds.addi(Nd4j.scalar(1.0E-6));
        this.setFeatures(this.getFeatures().diviRowVector(columnStds));
    }

    /** Sums {@link #numExamples()} over every data set in {@code coll}. */
    private static int totalExamples(Collection<DataSet> coll) {
        int count = 0;
        for (DataSet d : coll) {
            count += d.numExamples();
        }
        return count;
    }

    /** Reads the feature at linear index {@code i} as a double, whatever boxed type ND4J returns. */
    private double featureValueAt(int i) {
        // element() returns Object; going through Number avoids assuming Double vs Float.
        return ((Number) this.getFeatures().getScalar(i).element()).doubleValue();
    }

    /** Returns the number of feature columns. */
    @Override
    public int numInputs() {
        return this.getFeatures().columns();
    }

    /** @throws IllegalStateException if the features and labels have different row counts */
    @Override
    public void validate() {
        if (this.getFeatures().rows() != this.getLabels().rows()) {
            throw new IllegalStateException("Invalid dataset");
        }
    }

    /**
     * Returns the index of the largest-magnitude label entry (BLAS {@code iamax}).
     *
     * @throws IllegalStateException if this data set holds more than one example
     */
    @Override
    public int outcome() {
        if (this.numExamples() > 1) {
            throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
        }
        return Nd4j.getBlasWrapper().iamax(this.getLabels());
    }

    /** Replaces the labels with a freshly created {@code numExamples x labels} matrix. */
    @Override
    public void setNewNumberOfLabels(int labels) {
        this.setLabels(Nd4j.create(this.numExamples(), labels));
    }

    /**
     * Replaces the label row of {@code example} with the outcome vector for {@code label}.
     *
     * @throws IllegalArgumentException if {@code example} or {@code label} is out of range
     */
    @Override
    public void setOutcome(int example, int label) {
        if (example < 0 || example >= this.numExamples()) {
            throw new IllegalArgumentException("No example at " + example);
        }
        if (label < 0 || label >= this.numOutcomes()) {
            throw new IllegalArgumentException("Illegal label");
        }
        INDArray outcome = FeatureUtil.toOutcomeVector(label, this.numOutcomes());
        this.getLabels().putRow(example, outcome);
    }

    /**
     * Returns example {@code i} as a single-row data set built from {@code getRow(i)}.
     *
     * @throws IllegalArgumentException if {@code i} is not in {@code [0, numExamples())}
     */
    @Override
    public DataSet get(int i) {
        if (i < 0 || i >= this.numExamples()) {
            throw new IllegalArgumentException("invalid example number");
        }
        return new DataSet(this.getFeatures().getRow(i), this.getLabels().getRow(i));
    }

    /** Returns the examples at the given row indices as a data set built from {@code getRows}. */
    @Override
    public DataSet get(int[] i) {
        return new DataSet(this.getFeatures().getRows(i), this.getLabels().getRows(i));
    }

    /** Splits {@link #asList()} into consecutive sublists of size {@code num} (last may be shorter). */
    @Override
    public List<List<DataSet>> batchBy(int num) {
        return Lists.partition(this.asList(), num);
    }

    /**
     * Returns a new data set holding only the examples whose {@link #outcome()} is in
     * {@code labels}.
     *
     * @throws IllegalArgumentException (from {@link #merge}) if no example matches
     */
    @Override
    public DataSet filterBy(int[] labels) {
        Set<Integer> wanted = new HashSet<Integer>();
        for (int label : labels) {
            wanted.add(label);
        }
        List<DataSet> kept = new ArrayList<DataSet>();
        for (DataSet example : this.asList()) {
            if (wanted.contains(example.outcome())) {
                kept.add(example);
            }
        }
        return DataSet.merge(kept);
    }

    /**
     * Keeps only examples whose outcome is in {@code labels} and re-encodes their labels so that
     * {@code labels[k]} becomes outcome {@code k} in a {@code labels.length}-wide label matrix.
     */
    @Override
    public void filterAndStrip(int[] labels) {
        DataSet filtered = this.filterBy(labels);
        Map<Integer, Integer> labelMap = new HashMap<Integer, Integer>();
        for (int i = 0; i < labels.length; ++i) {
            labelMap.put(labels[i], i);
        }
        int rows = filtered.numExamples();
        INDArray newLabelMatrix = Nd4j.create(rows, labels.length);
        for (int row = 0; row < rows; ++row) {
            Integer newLabel = labelMap.get(filtered.get(row).outcome());
            if (newLabel == null) {
                throw new IllegalStateException("Label not found on row " + row);
            }
            newLabelMatrix.putRow(row, FeatureUtil.toOutcomeVector(newLabel, labels.length));
        }
        this.setFeatures(filtered.getFeatures());
        this.setLabels(newLabelMatrix);
    }

    /** Splits the examples into consecutive batches of {@code num}, each merged into a data set. */
    @Override
    public List<DataSet> dataSetBatches(int num) {
        List<DataSet> ret = new ArrayList<DataSet>();
        for (List<DataSet> batch : Lists.partition(this.asList(), num)) {
            ret.add(DataSet.merge(batch));
        }
        return ret;
    }

    /** Reorders the examples with {@link #sortByLabel()} and batches them by {@link #numOutcomes()}. */
    @Override
    public List<List<DataSet>> sortAndBatchByNumLabels() {
        this.sortByLabel();
        return Lists.partition(this.asList(), this.numOutcomes());
    }

    /** Batches the examples, in current order, into groups of {@link #numOutcomes()}. */
    @Override
    public List<List<DataSet>> batchByNumLabels() {
        return Lists.partition(this.asList(), this.numOutcomes());
    }

    /** Returns one single-row data set per example, built from {@code getRow}, in row order. */
    @Override
    public List<DataSet> asList() {
        int n = this.numExamples();
        List<DataSet> list = new ArrayList<DataSet>(n);
        for (int i = 0; i < n; ++i) {
            list.add(new DataSet(this.getFeatures().getRow(i), this.getLabels().getRow(i)));
        }
        return list;
    }

    /**
     * Shuffles the examples and splits them into two new data sets: the first
     * {@code numHoldout} shuffled examples become the first ("train") part, the rest the second.
     *
     * @throws IllegalArgumentException if {@code numHoldout} is not in {@code [1, numExamples())}
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout) {
        if (numHoldout >= this.numExamples()) {
            throw new IllegalArgumentException("Unable to split on size larger than the number of rows");
        }
        if (numHoldout < 1) {
            throw new IllegalArgumentException("numHoldout must be at least 1 but was " + numHoldout);
        }
        List<DataSet> list = this.asList();
        Collections.shuffle(list);
        DataSet train = DataSet.merge(list.subList(0, numHoldout));
        DataSet test = DataSet.merge(list.subList(numHoldout, list.size()));
        return new SplitTestAndTrain(train, test);
    }

    @Override
    public INDArray getLabels() {
        return this.labels;
    }

    /** Alias for {@link #getFeatures()}. */
    @Override
    public INDArray getFeatureMatrix() {
        return this.getFeatures();
    }

    /**
     * Reorders the examples in place so labels are interleaved.
     *
     * <p>Phase 1 emits one example per label, for labels {@code 0..numOutcomes()-1} in order,
     * repeating until some label has no example left. Phase 2 appends all remaining examples
     * grouped by ascending label.
     */
    @Override
    public void sortByLabel() {
        // Queue independent copies: asList() rows may be views into the matrices rewritten below.
        Map<Integer, Queue<DataSet>> byLabel = new TreeMap<Integer, Queue<DataSet>>();
        for (DataSet example : this.asList()) {
            DataSet snapshot = example.copy();
            int label = snapshot.outcome();
            Queue<DataSet> queue = byLabel.get(label);
            if (queue == null) {
                queue = new ArrayDeque<DataSet>();
                byLabel.put(label, queue);
            }
            queue.add(snapshot);
        }
        for (Map.Entry<Integer, Queue<DataSet>> entry : byLabel.entrySet()) {
            log.info("Label {} has {} elements", entry.getKey(), entry.getValue().size());
        }

        int numLabels = this.numOutcomes();
        int examples = this.numExamples();
        List<DataSet> ordered = new ArrayList<DataSet>(examples);

        // Phase 1: round-robin over labels while every label still has an example.
        boolean interleaving = numLabels > 0;
        while (interleaving && ordered.size() < examples) {
            for (int label = 0; label < numLabels; ++label) {
                Queue<DataSet> queue = byLabel.get(label);
                DataSet next = queue == null ? null : queue.poll();
                if (next == null) {
                    interleaving = false;
                    break;
                }
                ordered.add(next);
            }
        }
        // Phase 2: drain whatever is left, grouped by ascending label.
        for (Queue<DataSet> queue : byLabel.values()) {
            ordered.addAll(queue);
        }

        for (int row = 0; row < ordered.size(); ++row) {
            this.addRow(ordered.get(row), row);
        }
    }

    /**
     * Overwrites row {@code i} of the features and labels with the rows of {@code d}.
     *
     * @throws IllegalArgumentException if {@code d} is null or {@code i} is out of range
     */
    @Override
    public void addRow(DataSet d, int i) {
        if (d == null || i < 0 || i >= this.numExamples()) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        this.getFeatures().putRow(i, d.getFeatures());
        this.getLabels().putRow(i, d.getLabels());
    }

    /** Returns the per-example (row-wise, dimension 1) sum of the features. */
    @Override
    public INDArray exampleSums() {
        return this.getFeatures().sum(1);
    }

    /** Returns the per-example (row-wise, dimension 1) maximum of the features. */
    @Override
    public INDArray exampleMaxs() {
        return this.getFeatures().max(1);
    }

    /** Returns the per-example (row-wise, dimension 1) mean of the features. */
    @Override
    public INDArray exampleMeans() {
        return this.getFeatures().mean(1);
    }

    /** Samples without replacement using a time-seeded Mersenne Twister. */
    @Override
    public DataSet sample(int numSamples) {
        return this.sample(numSamples, (RandomGenerator)new MersenneTwister(System.currentTimeMillis()));
    }

    /** Samples without replacement using {@code rng}. */
    @Override
    public DataSet sample(int numSamples, RandomGenerator rng) {
        return this.sample(numSamples, rng, false);
    }

    /** Samples using a time-seeded Mersenne Twister. */
    @Override
    public DataSet sample(int numSamples, boolean withReplacement) {
        return this.sample(numSamples, (RandomGenerator)new MersenneTwister(System.currentTimeMillis()), withReplacement);
    }

    /**
     * Draws {@code numSamples} random examples into a new data set.
     *
     * @param numSamples      how many examples to draw; if not less than {@link #numExamples()}
     *                        this data set itself is returned
     * @param rng             source of row indices
     * @param withReplacement whether the same example may be drawn more than once
     * @return the sampled data set
     */
    @Override
    public DataSet sample(int numSamples, RandomGenerator rng, boolean withReplacement) {
        int available = this.numExamples();
        if (numSamples >= available) {
            return this;
        }
        INDArray examples = Nd4j.create(numSamples, this.getFeatures().columns());
        INDArray outcomes = Nd4j.create(numSamples, this.numOutcomes());
        Set<Integer> added = new HashSet<Integer>();
        for (int i = 0; i < numSamples; ++i) {
            int picked = rng.nextInt(available);
            if (!withReplacement) {
                while (added.contains(picked)) {
                    picked = rng.nextInt(available);
                }
                // Previously never recorded, so "without replacement" could repeat rows.
                added.add(picked);
            }
            DataSet example = this.get(picked);
            examples.putRow(i, example.getFeatures());
            outcomes.putRow(i, example.getLabels());
        }
        return new DataSet(examples, outcomes);
    }

    /** Rounds every feature in place via {@link MathUtils#roundDouble(double, int)}. */
    @Override
    public void roundToTheNearest(int roundTo) {
        for (int i = 0; i < this.getFeatures().length(); ++i) {
            double curr = this.featureValueAt(i);
            this.getFeatures().put(i, Nd4j.scalar(MathUtils.roundDouble(curr, roundTo)));
        }
    }

    /** Returns the number of label columns. */
    @Override
    public int numOutcomes() {
        return this.getLabels().columns();
    }

    /** Returns the number of feature rows. */
    @Override
    public int numExamples() {
        return this.getFeatures().rows();
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("===========INPUT===================\n").append(this.getFeatures().toString().replaceAll(";", "\n")).append("\n=================OUTPUT==================\n").append(this.getLabels().toString().replaceAll(";", "\n"));
        return builder.toString();
    }

    @Override
    public List<String> getLabelNames() {
        return this.labelNames;
    }

    /** @throws IllegalArgumentException if {@code labelNames} is null or its size differs from {@link #numOutcomes()} */
    @Override
    public void setLabelNames(List<String> labelNames) {
        if (labelNames == null || labelNames.size() != this.numOutcomes()) {
            throw new IllegalArgumentException("Unable to set label names, does not match number of possible outcomes");
        }
        this.labelNames = labelNames;
    }

    @Override
    public List<String> getColumnNames() {
        return this.columnNames;
    }

    /** @throws IllegalArgumentException if {@code columnNames} is null or its size differs from {@link #numInputs()} */
    @Override
    public void setColumnNames(List<String> columnNames) {
        if (columnNames == null || columnNames.size() != this.numInputs()) {
            throw new IllegalArgumentException("Column names don't match input");
        }
        this.columnNames = columnNames;
    }

    /** Iterates over the single-example data sets returned by {@link #asList()}. */
    @Override
    public Iterator<DataSet> iterator() {
        return this.asList().iterator();
    }
}

