/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.fetchers;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Data fetcher for the MNIST handwritten-digit data set.
 *
 * <p>Files are read from {@link #MNIST_ROOT} ({@code <user.home>/MNIST/}). If any of the four
 * expected files is missing, they are downloaded with {@link MnistFetcher#downloadAndUntar()}.
 * Each example's features are a single row holding one value per image byte. When
 * {@code binarize} is set, each value is 1 if the raw pixel is above {@value #BINARIZE_THRESHOLD}
 * and 0 otherwise. When it is not set, each value is the raw pixel divided by
 * {@value #MAX_PIXEL_VALUE}. Examples are visited in the order held by {@link #order}, which is
 * reshuffled on every {@link #reset()} when {@code shuffle} is enabled.
 */
public class MnistDataFetcher extends BaseDataFetcher {
    /** Number of examples in the MNIST training set. */
    public static final int NUM_EXAMPLES = 60000;
    /** Number of examples in the MNIST test set. */
    public static final int NUM_EXAMPLES_TEST = 10000;
    protected static final String TEMP_ROOT = System.getProperty("user.home");
    protected static final String MNIST_ROOT = TEMP_ROOT + File.separator + "MNIST" + File.separator;

    // File names expected inside MNIST_ROOT. They are shared by the constructor and mnistExists()
    // so the two cannot drift apart.
    // NOTE(review): the training-file names are presumably the ones MnistFetcher writes — confirm.
    private static final String TRAIN_IMAGES_FILE = "images-idx1-ubyte";
    private static final String TRAIN_LABELS_FILE = "labels-idx1-ubyte";
    private static final String TEST_IMAGES_FILE = "t10k-images-idx3-ubyte";
    private static final String TEST_LABELS_FILE = "t10k-labels-idx1-ubyte";

    /** When binarizing, raw pixels strictly greater than this become 1; all others become 0. */
    private static final int BINARIZE_THRESHOLD = 30;
    /** Divisor used to scale raw (non-binarized) pixel values. */
    private static final int MAX_PIXEL_VALUE = 255;
    /** Number of label classes (digits 0-9). */
    private static final int NUM_CLASSES = 10;

    protected transient MnistManager man;
    protected boolean binarize = true;
    protected boolean train;
    /** Permutation of example indices; {@code cursor} indexes into this array. */
    protected int[] order;
    protected Random rng;
    protected boolean shuffle;

    /**
     * Creates a shuffled fetcher over the training set, seeded with the current time.
     *
     * @param binarize whether to binarize pixels instead of scaling them
     * @throws IOException if the data cannot be downloaded or read
     */
    public MnistDataFetcher(boolean binarize) throws IOException {
        this(binarize, true, true, System.currentTimeMillis());
    }

    /**
     * Creates a fetcher over either the training or the test set.
     *
     * <p>If the reader cannot be created from the files on disk, the whole {@link #MNIST_ROOT}
     * directory is deleted, the data is re-downloaded, and creation is retried once. A failure on
     * the retry propagates to the caller.
     *
     * @param binarize whether to binarize pixels instead of scaling them
     * @param train {@code true} for the training set, {@code false} for the test set
     * @param shuffle whether to reshuffle the example order on each {@link #reset()}
     * @param rngSeed seed for the shuffling {@link Random}
     * @throws IOException if the data cannot be downloaded, deleted, or read
     */
    public MnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
        if (!this.mnistExists()) {
            new MnistFetcher().downloadAndUntar();
        }
        final String images;
        final String labels;
        final int numExamples;
        if (train) {
            images = MNIST_ROOT + TRAIN_IMAGES_FILE;
            labels = MNIST_ROOT + TRAIN_LABELS_FILE;
            numExamples = NUM_EXAMPLES;
        } else {
            images = MNIST_ROOT + TEST_IMAGES_FILE;
            labels = MNIST_ROOT + TEST_LABELS_FILE;
            numExamples = NUM_EXAMPLES_TEST;
        }
        this.totalExamples = numExamples;
        try {
            this.man = new MnistManager(images, labels, train);
        } catch (Exception e) {
            // The files exist but could not be loaded (e.g. a truncated/corrupt earlier download).
            // Wipe the directory, fetch a fresh copy, and try exactly once more; any failure on
            // this second attempt is thrown to the caller.
            FileUtils.deleteDirectory(new File(MNIST_ROOT));
            new MnistFetcher().downloadAndUntar();
            this.man = new MnistManager(images, labels, train);
        }
        this.numOutcomes = NUM_CLASSES;
        this.binarize = binarize;
        this.cursor = 0;
        this.inputColumns = this.man.getImages().getEntryLength();
        this.train = train;
        this.shuffle = shuffle;
        // Identity permutation. reset() below shuffles it in place when shuffle is enabled.
        this.order = new int[numExamples];
        for (int i = 0; i < this.order.length; ++i) {
            this.order[i] = i;
        }
        this.rng = new Random(rngSeed);
        // NOTE: reset() is overridable and is invoked while this object is still being constructed.
        this.reset();
    }

    /**
     * Returns whether all four MNIST files (training and test images and labels) exist under
     * {@link #MNIST_ROOT}.
     */
    private boolean mnistExists() {
        String[] required = {TRAIN_IMAGES_FILE, TRAIN_LABELS_FILE, TEST_IMAGES_FILE, TEST_LABELS_FILE};
        for (String name : required) {
            if (!new File(MNIST_ROOT, name).exists()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Creates a binarizing, shuffled fetcher over the training set.
     *
     * @throws IOException if the data cannot be downloaded or read
     */
    public MnistDataFetcher() throws IOException {
        this(true);
    }

    /**
     * Loads up to {@code numExamples} further examples, following {@link #order} from the current
     * cursor, and makes them the current batch.
     *
     * @param numExamples maximum number of examples to load; fewer are loaded if the data runs out
     * @throws IllegalStateException if no examples remain
     */
    @Override
    public void fetch(int numExamples) {
        if (!this.hasMore()) {
            throw new IllegalStateException("Unable to getFromOrigin more; there are no more images");
        }
        List<DataSet> toConvert = new ArrayList<>(numExamples);
        int loaded = 0;
        while (loaded < numExamples && this.hasMore()) {
            int exampleIndex = this.order[this.cursor];
            byte[] img = this.man.readImageUnsafe(exampleIndex);
            INDArray in = Nd4j.create(1, img.length);
            for (int j = 0; j < img.length; ++j) {
                // Java bytes are signed; mask to recover the unsigned pixel value 0..255.
                int pixel = img[j] & 0xFF;
                in.putScalar(j, this.binarize ? (pixel > BINARIZE_THRESHOLD ? 1 : 0) : pixel);
            }
            if (!this.binarize) {
                in.divi(MAX_PIXEL_VALUE);
            }
            INDArray out = this.createOutputVector(this.man.readLabel(exampleIndex));
            toConvert.add(new DataSet(in, out));
            ++loaded;
            ++this.cursor;
        }
        this.initializeCurrFromList(toConvert);
    }

    /**
     * Rewinds the cursor to the start and clears the current batch. When shuffling is enabled,
     * also permutes {@link #order} in place using {@link #rng}.
     */
    @Override
    public void reset() {
        this.cursor = 0;
        this.curr = null;
        if (this.shuffle) {
            MathUtils.shuffleArray(this.order, this.rng);
        }
    }

    /** Delegates directly to the base implementation. */
    @Override
    public DataSet next() {
        return super.next();
    }
}

