/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator.parallel;

import com.google.common.collect.Lists;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.IOFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.FileSplitDataSetIterator;
import org.deeplearning4j.datasets.iterator.callbacks.FileCallback;
import org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Parallel data set iterator that splits the matching files of one folder across several
 * producers.
 *
 * <p>Files directly inside {@code rootFolder} whose names match {@code pattern} are divided into
 * {@code numThreads} contiguous, near-equal groups. Each group is handed, together with the supplied
 * {@link FileCallback}, to a {@link FileSplitDataSetIterator}. That iterator is wrapped in an
 * {@link AsyncDataSetIterator} bound to device {@code groupIndex % numberOfDevices}. Consumer
 * {@code i} reads exclusively from group {@code i}.
 */
public class FileSplitParallelDataSetIterator extends BaseParallelDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(FileSplitParallelDataSetIterator.class);

    /** Default file name pattern; {@code %d} is the placeholder for the varying part of the name. */
    public static final String DEFAULT_PATTERN = "dataset-%d.bin";

    private final String pattern;
    private final int buffer;
    protected List<DataSetIterator> asyncIterators = new ArrayList<>();

    /**
     * Creates an iterator with one producer per device reported by the affinity manager,
     * {@link InequalityHandling#STOP_EVERYONE} and a buffer of 2 per thread.
     *
     * @param rootFolder existing folder containing the files
     * @param pattern    file name pattern, {@code %d} marks the varying part
     * @param callback   callback passed to each {@link FileSplitDataSetIterator}
     * @throws NullPointerException     if any argument is null
     * @throws IllegalArgumentException if the folder is invalid or holds too few matching files
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback) {
        this(rootFolder, pattern, callback, Nd4j.getAffinityManager().getNumberOfDevices());
    }

    /**
     * Creates an iterator with {@link InequalityHandling#STOP_EVERYONE} and a buffer of 2 per thread.
     *
     * @param rootFolder existing folder containing the files
     * @param pattern    file name pattern, {@code %d} marks the varying part
     * @param callback   callback passed to each {@link FileSplitDataSetIterator}
     * @param numThreads number of producers; must be positive and not exceed the number of matching files
     * @throws NullPointerException     if any object argument is null
     * @throws IllegalArgumentException if the folder is invalid, {@code numThreads} is not positive,
     *                                  or there are fewer matching files than threads
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback, int numThreads) {
        this(rootFolder, pattern, callback, numThreads, InequalityHandling.STOP_EVERYONE);
    }

    /**
     * Creates an iterator with a buffer of 2 per thread.
     *
     * @param rootFolder         existing folder containing the files
     * @param pattern            file name pattern, {@code %d} marks the varying part
     * @param callback           callback passed to each {@link FileSplitDataSetIterator}
     * @param numThreads         number of producers; must be positive and not exceed the number of matching files
     * @param inequalityHandling policy stored for the base class
     * @throws NullPointerException     if any object argument is null
     * @throws IllegalArgumentException if the folder is invalid, {@code numThreads} is not positive,
     *                                  or there are fewer matching files than threads
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback, int numThreads, @NonNull InequalityHandling inequalityHandling) {
        this(rootFolder, pattern, callback, numThreads, 2, inequalityHandling);
    }

    /**
     * Creates the iterator and starts one {@link AsyncDataSetIterator} per producer.
     *
     * @param rootFolder         existing folder containing the files (subfolders are not searched)
     * @param pattern            file name pattern, {@code %d} marks the varying part
     * @param callback           callback passed to each {@link FileSplitDataSetIterator}
     * @param numThreads         number of producers; must be positive and not exceed the number of matching files
     * @param bufferPerThread    queue size passed to each {@link AsyncDataSetIterator}
     * @param inequalityHandling policy stored for the base class
     * @throws NullPointerException     if any object argument is null
     * @throws IllegalArgumentException if the folder is invalid, {@code numThreads} is not positive,
     *                                  no file matches, or there are fewer matching files than threads
     */
    public FileSplitParallelDataSetIterator(@NonNull File rootFolder, @NonNull String pattern, @NonNull FileCallback callback, int numThreads, int bufferPerThread, @NonNull InequalityHandling inequalityHandling) {
        super(numThreads);
        if (rootFolder == null) {
            throw new NullPointerException("rootFolder");
        }
        if (pattern == null) {
            throw new NullPointerException("pattern");
        }
        if (callback == null) {
            throw new NullPointerException("callback");
        }
        if (inequalityHandling == null) {
            throw new NullPointerException("inequalityHandling");
        }
        if (numThreads <= 0) {
            throw new IllegalArgumentException("Number of threads should be positive, got " + numThreads);
        }
        if (!rootFolder.exists() || !rootFolder.isDirectory()) {
            throw new IllegalArgumentException("Root folder should point to existing folder");
        }
        this.pattern = pattern;
        this.inequalityHandling = inequalityHandling;
        this.buffer = bufferPerThread;

        // Only "%d" is translated; every other character of the pattern is used as regex syntax
        // as-is (e.g. the '.' in ".bin" matches any character).
        String fileRegex = pattern.replaceAll("\\%d", ".*.");
        IOFileFilter fileFilter = new RegexFileFilter(fileRegex);
        // A null directory filter means subdirectories are not searched.
        List<File> files = new ArrayList<>(FileUtils.listFiles(rootFolder, fileFilter, null));
        log.debug("Files found: {}; Producers: {}", files.size(), this.numProducers);
        if (files.isEmpty()) {
            throw new IllegalArgumentException("No suitable files were found");
        }
        if (files.size() < numThreads) {
            throw new IllegalArgumentException("Found " + files.size() + " suitable file(s), but at least one file per thread is required (numThreads = " + numThreads + ")");
        }

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int baseSize = files.size() / numThreads;
        int remainder = files.size() % numThreads;
        int from = 0;
        for (int group = 0; group < numThreads; group++) {
            // The first `remainder` groups take one extra file, so every file is assigned
            // exactly once and group sizes differ by at most one.
            int to = from + baseSize + (group < remainder ? 1 : 0);
            List<File> part = new ArrayList<>(files.subList(from, to));
            DataSetIterator fileIterator = new FileSplitDataSetIterator(part, callback);
            this.asyncIterators.add(new AsyncDataSetIterator(fileIterator, bufferPerThread, true, group % numDevices));
            from = to;
        }
    }

    /**
     * Returns the iterator backing {@code consumer}, validating the index first.
     *
     * @throws ND4JIllegalStateException if no such consumer exists
     */
    private DataSetIterator asyncIteratorForConsumer(int consumer) {
        if (consumer < 0 || consumer >= this.numProducers || consumer >= this.asyncIterators.size()) {
            throw new ND4JIllegalStateException("Non-existent consumer was requested");
        }
        return this.asyncIterators.get(consumer);
    }

    /** Returns whether the iterator of {@code consumer} has another element. */
    @Override
    public boolean hasNextFor(int consumer) {
        return asyncIteratorForConsumer(consumer).hasNext();
    }

    /** Returns the next element from the iterator of {@code consumer}. */
    @Override
    public DataSet nextFor(int consumer) {
        return asyncIteratorForConsumer(consumer).next();
    }

    /** Resets the iterator of {@code consumer}. */
    @Override
    protected void reset(int consumer) {
        asyncIteratorForConsumer(consumer).reset();
    }
}

