/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets;

import com.google.common.collect.Lists;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.util.MathUtils;
import org.jblas.DoubleMatrix;
import org.jblas.SimpleBlas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A labelled data set stored as a pair of matrices: the first holds the input
 * examples and the second holds the corresponding outcomes. Both matrices must
 * have the same number of rows; row {@code i} of each matrix belongs to
 * example {@code i}.
 *
 * <p>Several operations ({@link #sortByLabel()}, {@link #addRow(DataSet, int)},
 * {@link #roundToTheNearest(int)}, {@link #load(InputStream)}) modify the
 * underlying matrices in place.
 */
public class DataSet extends Pair<DoubleMatrix, DoubleMatrix> implements Persistable, Iterable<DataSet> {
    private static final long serialVersionUID = 1935520764586513365L;
    private static final Logger log = LoggerFactory.getLogger(DataSet.class);

    /**
     * Creates a data set from an existing (inputs, outcomes) pair.
     *
     * @param pair the inputs (first) and outcomes (second)
     * @throws IllegalStateException if the two matrices differ in row count
     */
    public DataSet(Pair<DoubleMatrix, DoubleMatrix> pair) {
        this(pair.getFirst(), pair.getSecond());
    }

    /**
     * Creates a data set from an input matrix and an outcome matrix.
     *
     * @param first  the input matrix
     * @param second the outcome matrix
     * @throws IllegalStateException if the two matrices differ in row count
     */
    public DataSet(DoubleMatrix first, DoubleMatrix second) {
        super(first, second);
        if (first.rows != second.rows) {
            throw new IllegalStateException("Invalid data set; first and second do not have equal rows. First was " + first.rows + " second was " + second.rows);
        }
    }

    /**
     * Returns an iterator over this data set split into batches.
     *
     * @param batches the partition size passed to {@link #dataSetBatches(int)}
     * @return an iterator over the merged batches
     */
    public DataSetIterator iterator(int batches) {
        return new ListDataSetIterator(this.dataSetBatches(batches));
    }

    /**
     * Returns an independent copy of this data set.
     *
     * @return a new data set whose matrices are duplicates of this one's, so
     *         in-place changes to either data set do not affect the other
     */
    public DataSet copy() {
        // dup() allocates a fresh buffer; passing the matrices straight through
        // would alias them and make this "copy" share all in-place mutations.
        return new DataSet(this.getFirst().dup(), this.getSecond().dup());
    }

    /**
     * Returns a placeholder data set built from {@code DoubleMatrix.zeros(1)}
     * for both the inputs and the outcomes.
     */
    public static DataSet empty() {
        return new DataSet(DoubleMatrix.zeros(1), DoubleMatrix.zeros(1));
    }

    /**
     * Stacks a list of data sets into one data set with one row per list
     * element.
     *
     * <p>Each element is written with {@code putRow}, so every element is
     * expected to hold a single example; the column counts are taken from the
     * first element.
     *
     * @param data the data sets to merge; must not be empty
     * @return the merged data set with {@code data.size()} rows
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public static DataSet merge(List<DataSet> data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }
        DataSet first = data.iterator().next();
        DoubleMatrix in = new DoubleMatrix(data.size(), first.getFirst().columns);
        DoubleMatrix out = new DoubleMatrix(data.size(), first.getSecond().columns);
        for (int i = 0; i < data.size(); ++i) {
            DataSet example = data.get(i);
            in.putRow(i, example.getFirst());
            out.putRow(i, example.getSecond());
        }
        return new DataSet(in, out);
    }

    /**
     * @return the number of columns of the input matrix
     */
    public int numInputs() {
        return this.getFirst().columns;
    }

    /**
     * Checks that inputs and outcomes still have the same number of rows.
     *
     * @throws IllegalStateException if the row counts differ
     */
    public void validate() {
        if (this.getFirst().rows != this.getSecond().rows) {
            throw new IllegalStateException("Invalid dataset");
        }
    }

    /**
     * Returns the outcome index of a single-example data set, as computed by
     * {@code SimpleBlas.iamax} over the outcome matrix.
     *
     * @return the index reported by {@code iamax}
     * @throws IllegalStateException if this data set has more than one example
     */
    public int outcome() {
        if (this.numExamples() > 1) {
            throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
        }
        return SimpleBlas.iamax(this.getSecond());
    }

    /**
     * Returns the example at row {@code i} as its own data set.
     *
     * @param i the row index
     * @return a data set built from row {@code i} of the inputs and outcomes
     */
    public DataSet get(int i) {
        return new DataSet(this.getFirst().getRow(i), this.getSecond().getRow(i));
    }

    /**
     * Partitions the examples into consecutive lists of at most {@code num}
     * single-example data sets.
     *
     * @param num the partition size
     * @return the partitions, in row order
     */
    public List<List<DataSet>> batchBy(int num) {
        return Lists.partition(this.asList(), num);
    }

    /**
     * Partitions the examples into consecutive groups of at most {@code num}
     * and merges each group into a single data set.
     *
     * @param num the partition size
     * @return one merged data set per partition, in row order
     */
    public List<DataSet> dataSetBatches(int num) {
        List<DataSet> ret = new ArrayList<DataSet>();
        for (List<DataSet> batch : Lists.partition(this.asList(), num)) {
            ret.add(DataSet.merge(batch));
        }
        return ret;
    }

    /**
     * Reorders this data set with {@link #sortByLabel()} and then partitions
     * it into lists of {@link #numOutcomes()} examples.
     *
     * @return the partitions of the reordered data set
     */
    public List<List<DataSet>> sortAndBatchByNumLabels() {
        this.sortByLabel();
        return Lists.partition(this.asList(), this.numOutcomes());
    }

    /**
     * Partitions the examples, in their current order, into lists of
     * {@link #numOutcomes()} examples.
     *
     * @return the partitions, in row order
     */
    public List<List<DataSet>> batchByNumLabels() {
        return Lists.partition(this.asList(), this.numOutcomes());
    }

    /**
     * Returns every example as its own single-row data set.
     *
     * @return a new list with {@link #numExamples()} elements, in row order
     */
    public List<DataSet> asList() {
        int total = this.numExamples();
        List<DataSet> list = new ArrayList<DataSet>(total);
        for (int i = 0; i < total; ++i) {
            list.add(this.get(i));
        }
        return list;
    }

    /**
     * Randomly splits the examples into two data sets.
     *
     * <p>The examples are shuffled; the first element of the returned pair
     * holds {@code numHoldout} examples and the second element holds the
     * remainder. Because an empty list cannot be merged, {@code numHoldout}
     * of zero fails with the {@link IllegalArgumentException} thrown by
     * {@link #merge(List)}.
     *
     * @param numHoldout the number of examples placed in the first data set
     * @return the pair of data sets produced by the split
     * @throws IllegalArgumentException if {@code numHoldout} is not smaller
     *         than the number of examples
     */
    public Pair<DataSet, DataSet> splitTestAndTrain(int numHoldout) {
        if (numHoldout >= this.numExamples()) {
            throw new IllegalArgumentException("Unable to split on size larger than the number of rows");
        }
        List<DataSet> list = this.asList();
        Collections.rotate(list, 3);
        Collections.shuffle(list);
        DataSet train = DataSet.merge(list.subList(0, numHoldout));
        DataSet test = DataSet.merge(list.subList(numHoldout, list.size()));
        return new Pair<DataSet, DataSet>(train, test);
    }

    /**
     * Reorders the rows of this data set in place so that labels are
     * interleaved for as long as possible.
     *
     * <p>Examples are grouped by label (see {@code SimpleBlas.iamax}). Rows are
     * then rewritten round-robin, one example for label 0, one for label 1 and
     * so on, until some label has no example left. The remaining examples are
     * appended queue by queue.
     */
    public void sortByLabel() {
        HashMap<Integer, Queue<DataSet>> byLabel = new HashMap<Integer, Queue<DataSet>>();
        int numLabels = this.numOutcomes();
        int examples = this.numExamples();

        // Snapshot every row first; the rows are overwritten below. This relies on
        // getRow handing back a copy rather than a view of the backing matrix.
        for (DataSet d : this.asList()) {
            int label = this.getLabel(d);
            Queue<DataSet> q = byLabel.get(label);
            if (q == null) {
                q = new ArrayDeque<DataSet>();
                byLabel.put(label, q);
            }
            q.add(d);
        }
        for (Integer label : byLabel.keySet()) {
            log.info("Label {} has {} elements", label, byLabel.get(label).size());
        }

        // Single write cursor: every polled example is written to the next free
        // row, so no row is skipped and no example is dropped.
        int row = 0;
        boolean optimal = numLabels > 0;
        while (optimal && row < examples) {
            for (int j = 0; j < numLabels; ++j) {
                // A label that never occurs has no queue at all.
                Queue<DataSet> q = byLabel.get(j);
                DataSet next = q == null ? null : q.poll();
                if (next == null) {
                    optimal = false;
                    break;
                }
                this.addRow(next, row++);
            }
        }

        // Round-robin is no longer possible: drain whatever is left.
        for (Queue<DataSet> q : byLabel.values()) {
            DataSet next;
            while ((next = q.poll()) != null) {
                this.addRow(next, row++);
            }
        }
    }

    /**
     * Overwrites row {@code i} of the inputs and outcomes with the given
     * example.
     *
     * @param d the example to write; must not be {@code null}
     * @param i the row to overwrite; must be in {@code [0, numExamples())}
     * @throws IllegalArgumentException if {@code d} is {@code null} or
     *         {@code i} is out of range
     */
    public void addRow(DataSet d, int i) {
        if (d == null || i < 0 || i >= this.numExamples()) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        this.getFirst().putRow(i, d.getFirst());
        this.getSecond().putRow(i, d.getSecond());
    }

    /** Returns the label index of {@code data} as computed by {@code SimpleBlas.iamax}. */
    private int getLabel(DataSet data) {
        return SimpleBlas.iamax(data.getSecond());
    }

    /**
     * @return the column sums of the input matrix
     */
    public DoubleMatrix exampleSums() {
        return this.getFirst().columnSums();
    }

    /**
     * @return the column maxima of the input matrix
     */
    public DoubleMatrix exampleMaxs() {
        return this.getFirst().columnMaxs();
    }

    /**
     * @return the column means of the input matrix
     */
    public DoubleMatrix exampleMeans() {
        return this.getFirst().columnMeans();
    }

    /**
     * Writes this data set to a file, replacing any existing file.
     *
     * <p>In binary mode the input matrix and then the outcome matrix are
     * written with {@code DoubleMatrix.out}. In text mode each example is
     * written as one line: the formatted input row, a tab, the formatted
     * outcome row.
     *
     * @param file   the destination file
     * @param binary {@code true} for the binary format, {@code false} for text
     * @throws IOException if the file cannot be created or written
     */
    public void saveTo(File file, boolean binary) throws IOException {
        if (file.exists()) {
            file.delete();
        }
        file.createNewFile();
        // try-with-resources closes (and thereby flushes) the stream even when a
        // write fails part-way through.
        if (binary) {
            try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
                this.getFirst().out(out);
                this.getSecond().out(out);
            }
        } else {
            try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(file))) {
                for (int i = 0; i < this.numExamples(); ++i) {
                    out.write(this.getFirst().getRow(i).toString("%.3f", "[", "]", ", ", ";").getBytes());
                    out.write("\t".getBytes());
                    out.write(this.getSecond().getRow(i).toString("%.3f", "[", "]", ", ", ";").getBytes());
                    out.write("\n".getBytes());
                }
            }
        }
    }

    /**
     * Reads a data set from a file in the binary format: an input matrix
     * followed by an outcome matrix, each read with {@code DoubleMatrix.in}.
     *
     * @param path the file to read
     * @return the data set read from the file
     * @throws IOException if the file cannot be opened or read
     */
    public static DataSet load(File path) throws IOException {
        DoubleMatrix x = new DoubleMatrix(1, 1);
        DoubleMatrix y = new DoubleMatrix(1, 1);
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
            x.in(in);
            y.in(in);
        }
        return new DataSet(x, y);
    }

    /**
     * Samples without replacement using a time-seeded generator.
     *
     * @param numSamples the number of examples to draw
     * @return the sampled data set
     * @see #sample(int, RandomGenerator, boolean)
     */
    public DataSet sample(int numSamples) {
        return this.sample(numSamples, new MersenneTwister(System.currentTimeMillis()));
    }

    /**
     * Samples without replacement using the given generator.
     *
     * @param numSamples the number of examples to draw
     * @param rng        the source of randomness
     * @return the sampled data set
     * @see #sample(int, RandomGenerator, boolean)
     */
    public DataSet sample(int numSamples, RandomGenerator rng) {
        return this.sample(numSamples, rng, false);
    }

    /**
     * Samples using a time-seeded generator.
     *
     * @param numSamples      the number of examples to draw
     * @param withReplacement whether the same example may be drawn repeatedly
     * @return the sampled data set
     * @see #sample(int, RandomGenerator, boolean)
     */
    public DataSet sample(int numSamples, boolean withReplacement) {
        return this.sample(numSamples, new MersenneTwister(System.currentTimeMillis()), withReplacement);
    }

    /**
     * Draws {@code numSamples} random examples from this data set.
     *
     * @param numSamples      the number of examples to draw
     * @param rng             the source of randomness
     * @param withReplacement if {@code true} an example may be drawn more than
     *                        once; if {@code false} every drawn example is
     *                        distinct
     * @return a new data set holding the drawn examples, or this data set
     *         itself when {@code numSamples} is not smaller than the number of
     *         examples
     */
    public DataSet sample(int numSamples, RandomGenerator rng, boolean withReplacement) {
        int total = this.numExamples();
        if (numSamples >= total) {
            return this;
        }
        DoubleMatrix examples = new DoubleMatrix(numSamples, this.getFirst().columns);
        DoubleMatrix outcomes = new DoubleMatrix(numSamples, this.numOutcomes());
        HashSet<Integer> added = new HashSet<Integer>();
        for (int i = 0; i < numSamples; ++i) {
            int picked = rng.nextInt(total);
            if (!withReplacement) {
                // Redraw until an unused row comes up; this terminates because
                // numSamples < total leaves at least one row unused.
                while (!added.add(picked)) {
                    picked = rng.nextInt(total);
                }
            }
            // Copy the row that was drawn, not the row matching the loop index.
            examples.putRow(i, this.getFirst().getRow(picked));
            outcomes.putRow(i, this.getSecond().getRow(picked));
        }
        return new DataSet(examples, outcomes);
    }

    /**
     * Replaces every element of the input matrix, in place, with
     * {@code MathUtils.roundDouble(value, roundTo)}.
     *
     * @param roundTo the second argument passed to {@code MathUtils.roundDouble}
     */
    public void roundToTheNearest(int roundTo) {
        DoubleMatrix inputs = this.getFirst();
        for (int i = 0; i < inputs.length; ++i) {
            inputs.put(i, MathUtils.roundDouble(inputs.get(i), roundTo));
        }
    }

    /**
     * @return the number of columns of the outcome matrix
     */
    public int numOutcomes() {
        return this.getSecond().columns;
    }

    /**
     * @return the number of rows of the input matrix
     */
    public int numExamples() {
        return this.getFirst().rows;
    }

    /**
     * Renders the inputs and outcomes under banner headings, with every
     * {@code ;} in the matrices' string form replaced by a newline.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("===========INPUT===================\n")
                .append(this.getFirst().toString().replace(";", "\n"))
                .append("\n=================OUTPUT==================\n")
                .append(this.getSecond().toString().replace(";", "\n"));
        return builder.toString();
    }

    /**
     * Fetches 100 MNIST examples and writes them in text format to the file
     * named by the first argument.
     *
     * @param args {@code args[0]} is the output file path
     * @throws IOException if fetching or writing fails
     */
    public static void main(String[] args) throws IOException {
        MnistDataFetcher fetcher = new MnistDataFetcher();
        fetcher.fetch(100);
        DataSet write = new DataSet(fetcher.next());
        write.saveTo(new File(args[0]), false);
    }

    /**
     * Writes the input matrix followed by the outcome matrix to the stream.
     * The stream is not closed.
     *
     * @param os the destination stream
     * @throws RuntimeException wrapping any {@link IOException} from writing
     */
    @Override
    public void write(OutputStream os) {
        DataOutputStream dos = new DataOutputStream(os);
        try {
            this.getFirst().out(dos);
            this.getSecond().out(dos);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads an input matrix and then an outcome matrix from the stream into
     * this data set's existing matrices. The stream is not closed.
     *
     * @param is the source stream
     * @throws RuntimeException wrapping any {@link IOException} from reading
     */
    @Override
    public void load(InputStream is) {
        DataInputStream dis = new DataInputStream(is);
        try {
            this.getFirst().in(dis);
            this.getSecond().in(dis);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return an iterator over the single-example data sets of {@link #asList()}
     */
    @Override
    public Iterator<DataSet> iterator() {
        return this.asList().iterator();
    }
}

