/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.beans.ConstructorProperties;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

/**
 * Rebalances the examples produced by a {@link DataSetIterator} so that each saved mini-batch
 * draws examples from every label in round-robin order.
 *
 * <p>{@link #balance()} works in two phases:
 * <ol>
 *   <li>every example is split out of its source batch and saved to its own file under
 *       {@code <labelRootDir>/<n>}, where the label index is {@link DataSet#outcome()} and
 *       {@code <n>} counts the examples already stored for that label;</li>
 *   <li>the per-label files are read back one label at a time (label {@code 0 .. numLabels-1},
 *       repeated), merged into mini-batches and written to
 *       {@code <rootSaveDir>/dataset-<k>.bin}.</li>
 * </ol>
 *
 * <p>Instances are mutable and not synchronized; {@link #balance()} mutates the objects returned
 * by {@link #getPaths()} and {@link #getLabelRootDirs()}.
 */
public class BalanceMinibatches {
    /** Directory used for the per-label example files when none is configured. */
    private static final String DEFAULT_ROOT_DIR = "minibatches";
    /** Directory used for the balanced output mini-batches when none is configured. */
    private static final String DEFAULT_ROOT_SAVE_DIR = "minibatchessave";

    private DataSetIterator dataSetIterator;
    private int numLabels;
    private Map<Integer, List<File>> paths = Maps.newHashMap();
    private int miniBatchSize = -1;
    private File rootDir = new File(DEFAULT_ROOT_DIR);
    private File rootSaveDir = new File(DEFAULT_ROOT_SAVE_DIR);
    private List<File> labelRootDirs = new ArrayList<File>();
    private DataNormalization dataNormalization;

    /**
     * Splits every example from the configured iterator into per-label files, then reassembles
     * them into label-interleaved mini-batches saved under {@link #getRootSaveDir()}.
     *
     * <p>If {@code miniBatchSize} is not positive, it is inferred from the size of the first
     * non-empty batch returned by the iterator. Each saved mini-batch is built from whole
     * round-robin passes over the labels that still have examples. A batch may therefore hold
     * up to {@code numLabels - 1} examples more than {@code miniBatchSize}. If a
     * {@link DataNormalization} is configured, it is applied to each merged mini-batch before
     * the batch is saved.
     *
     * @throws NullPointerException if no {@link DataSetIterator} has been configured
     */
    public void balance() {
        Objects.requireNonNull(this.dataSetIterator, "dataSetIterator");
        // Null directories can arrive via the constructor or setters; fall back to the defaults.
        if (this.rootDir == null) {
            this.rootDir = new File(DEFAULT_ROOT_DIR);
        }
        if (this.rootSaveDir == null) {
            this.rootSaveDir = new File(DEFAULT_ROOT_SAVE_DIR);
        }
        if (!this.rootDir.exists()) {
            this.rootDir.mkdirs();
        }
        if (!this.rootSaveDir.exists()) {
            // Previously this created rootDir a second time instead of the output directory.
            this.rootSaveDir.mkdirs();
        }
        if (this.paths == null) {
            this.paths = Maps.newHashMap();
        }
        if (this.labelRootDirs == null) {
            this.labelRootDirs = Lists.newArrayList();
        }
        for (int i = 0; i < this.numLabels; ++i) {
            this.paths.put(i, new ArrayList<File>());
            this.labelRootDirs.add(new File(this.rootDir, String.valueOf(i)));
        }
        writeExamplesByLabel();
        writeBalancedMiniBatches();
    }

    /** Phase 1: saves each example from the iterator to its label directory and records its file. */
    private void writeExamplesByLabel() {
        while (this.dataSetIterator.hasNext()) {
            DataSet next = this.dataSetIterator.next();
            // A non-positive size (e.g. 0 from the builder) means "infer from the data".
            if (this.miniBatchSize <= 0) {
                this.miniBatchSize = next.numExamples();
            }
            int count = next.numExamples();
            for (int i = 0; i < count; ++i) {
                DataSet example = next.get(i);
                int label = example.outcome();
                File labelDir = this.labelRootDirs.get(label);
                if (!labelDir.exists()) {
                    labelDir.mkdirs();
                }
                List<File> labelFiles = this.paths.get(label);
                // File name is the number of examples already stored for this label.
                File target = new File(labelDir, String.valueOf(labelFiles.size()));
                example.save(target);
                labelFiles.add(target);
            }
        }
    }

    /** Phase 2: drains the per-label files round-robin into merged mini-batches and saves them. */
    private void writeBalancedMiniBatches() {
        // Clamp to at least 1 so every outer iteration consumes input and the loop terminates.
        final int batchLimit = Math.max(1, this.miniBatchSize);
        int numsSaved = 0;
        boolean drained = false;
        while (!drained) {
            List<DataSet> miniBatch = new ArrayList<DataSet>();
            while (miniBatch.size() < batchLimit) {
                if (!takeOneRound(miniBatch)) {
                    drained = true;
                    break;
                }
            }
            // The final pass can find every label empty; never merge or save an empty batch.
            if (!miniBatch.isEmpty()) {
                saveMiniBatch(miniBatch, numsSaved++);
            }
        }
    }

    /**
     * Moves the next stored example of every label that still has one into {@code miniBatch}.
     * Labels with no remaining files are removed from {@code paths}.
     *
     * @return whether at least one example was added
     */
    private boolean takeOneRound(List<DataSet> miniBatch) {
        boolean added = false;
        for (int label = 0; label < this.numLabels; ++label) {
            List<File> files = this.paths.get(label);
            if (files == null || files.isEmpty()) {
                this.paths.remove(label);
                continue;
            }
            DataSet example = new DataSet();
            example.load(files.remove(0));
            miniBatch.add(example);
            added = true;
        }
        return added;
    }

    /** Merges, optionally normalizes, and saves one mini-batch as {@code dataset-<index>.bin}. */
    private void saveMiniBatch(List<DataSet> miniBatch, int index) {
        if (!this.rootSaveDir.exists()) {
            this.rootSaveDir.mkdirs();
        }
        DataSet merged = DataSet.merge(miniBatch);
        if (this.dataNormalization != null) {
            this.dataNormalization.transform(merged);
        }
        // Plain concatenation keeps the digits ASCII regardless of the default locale.
        merged.save(new File(this.rootSaveDir, "dataset-" + index + ".bin"));
    }

    /** Returns a new builder whose defaults match this class's field defaults. */
    public static BalanceMinibatchesBuilder builder() {
        return new BalanceMinibatchesBuilder();
    }

    /**
     * Creates an instance with every field given explicitly.
     *
     * @param dataSetIterator source of the examples to rebalance
     * @param numLabels number of distinct label indices ({@link DataSet#outcome()} values)
     * @param paths per-label example files; may be {@code null} (created by {@link #balance()})
     * @param miniBatchSize target mini-batch size; non-positive means infer from the iterator
     * @param rootDir directory for per-label example files; {@code null} means the default
     * @param rootSaveDir directory for balanced output files; {@code null} means the default
     * @param labelRootDirs per-label directories; may be {@code null} (created by {@link #balance()})
     * @param dataNormalization optional normalizer applied to each saved mini-batch
     */
    @ConstructorProperties(value={"dataSetIterator", "numLabels", "paths", "miniBatchSize", "rootDir", "rootSaveDir", "labelRootDirs", "dataNormalization"})
    public BalanceMinibatches(DataSetIterator dataSetIterator, int numLabels, Map<Integer, List<File>> paths, int miniBatchSize, File rootDir, File rootSaveDir, List<File> labelRootDirs, DataNormalization dataNormalization) {
        this.dataSetIterator = dataSetIterator;
        this.numLabels = numLabels;
        this.paths = paths;
        this.miniBatchSize = miniBatchSize;
        this.rootDir = rootDir;
        this.rootSaveDir = rootSaveDir;
        this.labelRootDirs = labelRootDirs;
        this.dataNormalization = dataNormalization;
    }

    /** @return the iterator supplying the examples */
    public DataSetIterator getDataSetIterator() {
        return this.dataSetIterator;
    }

    /** @return the number of label indices processed */
    public int getNumLabels() {
        return this.numLabels;
    }

    /** @return the live map from label index to its stored example files */
    public Map<Integer, List<File>> getPaths() {
        return this.paths;
    }

    /** @return the configured or inferred mini-batch size */
    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    /** @return the directory holding per-label example files */
    public File getRootDir() {
        return this.rootDir;
    }

    /** @return the directory receiving balanced mini-batch files */
    public File getRootSaveDir() {
        return this.rootSaveDir;
    }

    /** @return the live list of per-label directories, indexed by label */
    public List<File> getLabelRootDirs() {
        return this.labelRootDirs;
    }

    /** @return the normalizer applied to saved mini-batches, or {@code null} */
    public DataNormalization getDataNormalization() {
        return this.dataNormalization;
    }

    /** @param dataSetIterator the iterator supplying the examples */
    public void setDataSetIterator(DataSetIterator dataSetIterator) {
        this.dataSetIterator = dataSetIterator;
    }

    /** @param numLabels the number of label indices to process */
    public void setNumLabels(int numLabels) {
        this.numLabels = numLabels;
    }

    /** @param paths map from label index to stored example files */
    public void setPaths(Map<Integer, List<File>> paths) {
        this.paths = paths;
    }

    /** @param miniBatchSize target mini-batch size; non-positive means infer from the iterator */
    public void setMiniBatchSize(int miniBatchSize) {
        this.miniBatchSize = miniBatchSize;
    }

    /** @param rootDir directory for per-label example files */
    public void setRootDir(File rootDir) {
        this.rootDir = rootDir;
    }

    /** @param rootSaveDir directory for balanced mini-batch files */
    public void setRootSaveDir(File rootSaveDir) {
        this.rootSaveDir = rootSaveDir;
    }

    /** @param labelRootDirs per-label directories, indexed by label */
    public void setLabelRootDirs(List<File> labelRootDirs) {
        this.labelRootDirs = labelRootDirs;
    }

    /** @param dataNormalization optional normalizer applied to saved mini-batches */
    public void setDataNormalization(DataNormalization dataNormalization) {
        this.dataNormalization = dataNormalization;
    }

    /** Field-by-field equality over all eight properties (null-safe). */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BalanceMinibatches)) {
            return false;
        }
        BalanceMinibatches other = (BalanceMinibatches) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getDataSetIterator(), other.getDataSetIterator())
                && this.getNumLabels() == other.getNumLabels()
                && Objects.equals(this.getPaths(), other.getPaths())
                && this.getMiniBatchSize() == other.getMiniBatchSize()
                && Objects.equals(this.getRootDir(), other.getRootDir())
                && Objects.equals(this.getRootSaveDir(), other.getRootSaveDir())
                && Objects.equals(this.getLabelRootDirs(), other.getLabelRootDirs())
                && Objects.equals(this.getDataNormalization(), other.getDataNormalization());
    }

    /** Lets subclasses opt out of equality with this class (symmetric {@code equals}). */
    protected boolean canEqual(Object other) {
        return other instanceof BalanceMinibatches;
    }

    /** Hash consistent with {@link #equals(Object)}; {@code null} fields contribute 0. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Objects.hashCode(this.getDataSetIterator());
        result = result * prime + this.getNumLabels();
        result = result * prime + Objects.hashCode(this.getPaths());
        result = result * prime + this.getMiniBatchSize();
        result = result * prime + Objects.hashCode(this.getRootDir());
        result = result * prime + Objects.hashCode(this.getRootSaveDir());
        result = result * prime + Objects.hashCode(this.getLabelRootDirs());
        result = result * prime + Objects.hashCode(this.getDataNormalization());
        return result;
    }

    @Override
    public String toString() {
        return "BalanceMinibatches(dataSetIterator=" + this.getDataSetIterator() + ", numLabels=" + this.getNumLabels() + ", paths=" + this.getPaths() + ", miniBatchSize=" + this.getMiniBatchSize() + ", rootDir=" + this.getRootDir() + ", rootSaveDir=" + this.getRootSaveDir() + ", labelRootDirs=" + this.getLabelRootDirs() + ", dataNormalization=" + this.getDataNormalization() + ")";
    }

    /**
     * Fluent builder for {@link BalanceMinibatches}.
     *
     * <p>{@code miniBatchSize}, {@code rootDir} and {@code rootSaveDir} default to the same values
     * as the enclosing class's fields. {@code paths} and {@code labelRootDirs} default to
     * {@code null}, which {@link BalanceMinibatches#balance()} replaces with empty collections.
     */
    public static class BalanceMinibatchesBuilder {
        private DataSetIterator dataSetIterator;
        private int numLabels;
        private Map<Integer, List<File>> paths;
        // Mirror the enclosing class's defaults. Without them the builder passed a null
        // rootDir/rootSaveDir (NPE in balance()) and a mini-batch size of 0.
        private int miniBatchSize = -1;
        private File rootDir = new File(DEFAULT_ROOT_DIR);
        private File rootSaveDir = new File(DEFAULT_ROOT_SAVE_DIR);
        private List<File> labelRootDirs;
        private DataNormalization dataNormalization;

        BalanceMinibatchesBuilder() {
        }

        public BalanceMinibatchesBuilder dataSetIterator(DataSetIterator dataSetIterator) {
            this.dataSetIterator = dataSetIterator;
            return this;
        }

        public BalanceMinibatchesBuilder numLabels(int numLabels) {
            this.numLabels = numLabels;
            return this;
        }

        public BalanceMinibatchesBuilder paths(Map<Integer, List<File>> paths) {
            this.paths = paths;
            return this;
        }

        public BalanceMinibatchesBuilder miniBatchSize(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
            return this;
        }

        public BalanceMinibatchesBuilder rootDir(File rootDir) {
            this.rootDir = rootDir;
            return this;
        }

        public BalanceMinibatchesBuilder rootSaveDir(File rootSaveDir) {
            this.rootSaveDir = rootSaveDir;
            return this;
        }

        public BalanceMinibatchesBuilder labelRootDirs(List<File> labelRootDirs) {
            this.labelRootDirs = labelRootDirs;
            return this;
        }

        public BalanceMinibatchesBuilder dataNormalization(DataNormalization dataNormalization) {
            this.dataNormalization = dataNormalization;
            return this;
        }

        /** @return a new {@link BalanceMinibatches} holding the builder's current values */
        public BalanceMinibatches build() {
            return new BalanceMinibatches(this.dataSetIterator, this.numLabels, this.paths, this.miniBatchSize, this.rootDir, this.rootSaveDir, this.labelRootDirs, this.dataNormalization);
        }

        @Override
        public String toString() {
            return "BalanceMinibatches.BalanceMinibatchesBuilder(dataSetIterator=" + this.dataSetIterator + ", numLabels=" + this.numLabels + ", paths=" + this.paths + ", miniBatchSize=" + this.miniBatchSize + ", rootDir=" + this.rootDir + ", rootSaveDir=" + this.rootSaveDir + ", labelRootDirs=" + this.labelRootDirs + ", dataNormalization=" + this.dataNormalization + ")";
        }
    }
}

