/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator.parallel;

import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.parallel.BaseParallelDataSetIterator;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Parallel data set iterator that joins several source {@link DataSetIterator}s and serves each of
 * them to a consumer addressed by its zero-based index.
 *
 * <p>Every source iterator is wrapped into an {@link AsyncDataSetIterator}; wrappers are assigned to
 * devices round-robin, in the order the sources were supplied.
 *
 * <p>NOTE(review): {@code enforceSingleDevice} is stored but never consulted in this class, so it
 * currently has no effect on device assignment here. Confirm whether that is intended.
 */
public class JointParallelDataSetIterator
extends BaseParallelDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(JointParallelDataSetIterator.class);

    /** Async wrappers around the source iterators; index {@code i} serves consumer {@code i}. */
    protected List<DataSetIterator> asyncIterators = new ArrayList<DataSetIterator>();

    /** Value of the {@code singleDeviceMode} constructor argument; not read anywhere in this class. */
    protected boolean enforceSingleDevice;

    /** Buffer size handed to every {@link AsyncDataSetIterator} created by this instance. */
    protected int bufferSizePerDevice;

    /**
     * Creates a joint iterator over the given sources.
     *
     * @param iterators          source iterators, one per consumer; must not be {@code null} or empty
     * @param singleDeviceMode   stored in {@link #enforceSingleDevice}
     * @param bufferSize         buffer size passed to each async wrapper
     * @param inequalityHandling stored in the inherited {@code inequalityHandling} field;
     *                           must not be {@code null}
     * @throws NullPointerException     if {@code iterators} or {@code inequalityHandling} is {@code null}
     * @throws IllegalArgumentException if {@code iterators} is empty
     */
    public JointParallelDataSetIterator(@NonNull List<DataSetIterator> iterators, boolean singleDeviceMode, int bufferSize, @NonNull InequalityHandling inequalityHandling) {
        // The null check has to happen inside the super(...) argument: the superclass constructor
        // call must come first, and a plain iterators.size() there would dereference a null list
        // before any check placed in the constructor body could run.
        super(producerCount(iterators));
        if (inequalityHandling == null) {
            throw new NullPointerException("inequalityHandling is marked @NonNull but is null");
        }
        this.enforceSingleDevice = singleDeviceMode;
        this.bufferSizePerDevice = bufferSize;
        this.numProducers = iterators.size();
        this.inequalityHandling = inequalityHandling;
        if (this.numProducers == 0) {
            throw new IllegalArgumentException("You can't start ParallelDataSetIterator without input data");
        }
        this.initializeIterators(iterators);
    }

    /**
     * Returns the number of source iterators, rejecting a {@code null} list with a descriptive message.
     *
     * @param iterators candidate list of source iterators
     * @return {@code iterators.size()}
     * @throws NullPointerException if {@code iterators} is {@code null}
     */
    private static int producerCount(List<DataSetIterator> iterators) {
        if (iterators == null) {
            throw new NullPointerException("iterators is marked @NonNull but is null");
        }
        return iterators.size();
    }

    /**
     * Wraps every source iterator into an {@link AsyncDataSetIterator} and appends it to
     * {@link #asyncIterators}. The n-th source (zero-based) is assigned device
     * {@code n % numberOfDevices}.
     *
     * @param originals source iterators, in consumer order
     */
    protected void initializeIterators(List<DataSetIterator> originals) {
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

        // The returned device id is not used below. The call itself is retained because whether it
        // has side effects cannot be determined from this file.
        Nd4j.getAffinityManager().getDeviceForCurrentThread();

        if (originals.size() % numDevices != 0) {
            // NOTE(review): message says WARNING but is emitted at error level; left as-is.
            log.error("WARNING: number of splits doesn't match number of devices!");
        }

        int position = 0;
        for (DataSetIterator source : originals) {
            int device = position % numDevices;
            this.asyncIterators.add((DataSetIterator)new AsyncDataSetIterator(source, this.bufferSizePerDevice, true, Integer.valueOf(device)));
            ++position;
        }
    }

    /**
     * Verifies that {@code consumer} addresses an existing producer.
     *
     * @param consumer zero-based consumer index
     * @throws ND4JIllegalStateException if {@code consumer} is negative or not below {@code numProducers}
     */
    private void checkConsumer(int consumer) {
        if (consumer >= this.numProducers || consumer < 0) {
            throw new ND4JIllegalStateException("Non-existent consumer was requested");
        }
    }

    /**
     * Reports whether the iterator serving the given consumer has more data.
     *
     * @param consumer zero-based consumer index
     * @return result of {@code hasNext()} on that consumer's async iterator
     * @throws ND4JIllegalStateException if {@code consumer} is out of range
     */
    @Override
    public boolean hasNextFor(int consumer) {
        this.checkConsumer(consumer);
        return this.asyncIterators.get(consumer).hasNext();
    }

    /**
     * Fetches the next data set for the given consumer.
     *
     * @param consumer zero-based consumer index
     * @return result of {@code next()} on that consumer's async iterator
     * @throws ND4JIllegalStateException if {@code consumer} is out of range
     */
    @Override
    public DataSet nextFor(int consumer) {
        this.checkConsumer(consumer);
        return (DataSet)this.asyncIterators.get(consumer).next();
    }

    /**
     * Resets the async iterator serving the given consumer.
     *
     * @param consumer zero-based consumer index
     * @throws ND4JIllegalStateException if {@code consumer} is out of range
     */
    @Override
    protected void reset(int consumer) {
        this.checkConsumer(consumer);
        this.asyncIterators.get(consumer).reset();
    }

    /**
     * Fluent builder for {@link JointParallelDataSetIterator}.
     *
     * <p>Defaults: {@code enforceSingleDevice = true}, buffer size {@code 4}.
     */
    public static class Builder {
        private final List<DataSetIterator> iterators = new ArrayList<DataSetIterator>();
        private boolean enforceSingleDevice = true;
        private int bufferSize = 4;
        private final InequalityHandling inequalityHandling;

        /**
         * Creates a builder with no source iterators.
         *
         * @param inequalityHandling value passed on to the built iterator; must not be {@code null}
         * @throws NullPointerException if {@code inequalityHandling} is {@code null}
         */
        public Builder(@NonNull InequalityHandling inequalityHandling) {
            if (inequalityHandling == null) {
                throw new NullPointerException("inequalityHandling is marked @NonNull but is null");
            }
            this.inequalityHandling = inequalityHandling;
        }

        /**
         * Creates a builder pre-populated with the given source iterators. Each element goes through
         * {@link #addSourceIterator(DataSetIterator)} and is therefore subject to the same checks.
         *
         * @param iterators          initial source iterators; must not be {@code null}
         * @param inequalityHandling value passed on to the built iterator; must not be {@code null}
         * @throws NullPointerException     if an argument or a list element is {@code null}
         * @throws IllegalArgumentException if an element is rejected by {@code addSourceIterator}
         */
        public Builder(@NonNull List<DataSetIterator> iterators, @NonNull InequalityHandling inequalityHandling) {
            if (iterators == null) {
                throw new NullPointerException("iterators is marked @NonNull but is null");
            }
            if (inequalityHandling == null) {
                throw new NullPointerException("inequalityHandling is marked @NonNull but is null");
            }
            this.inequalityHandling = inequalityHandling;
            for (DataSetIterator source : iterators) {
                this.addSourceIterator(source);
            }
        }

        /**
         * Appends one source iterator.
         *
         * @param iterator source iterator; must not be {@code null}, must report
         *                 {@code asyncSupported()}, and must not be an instance already added
         * @return this builder
         * @throws NullPointerException     if {@code iterator} is {@code null}
         * @throws IllegalArgumentException if the iterator does not support async mode or the very
         *                                  same instance was added before
         */
        public Builder addSourceIterator(@NonNull DataSetIterator iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            if (!iterator.asyncSupported()) {
                throw new IllegalArgumentException("Source iterators should support async mode");
            }
            if (this.hasIterator(iterator)) {
                throw new IllegalArgumentException("You can't put equal iterators into this joint iterator");
            }
            this.iterators.add(iterator);
            return this;
        }

        /**
         * Checks whether the given instance was already added. The comparison is by reference
         * ({@code ==}), not by {@code equals}.
         *
         * @param iterator instance to look for
         * @return {@code true} if the same instance is already registered
         */
        protected boolean hasIterator(DataSetIterator iterator) {
            for (DataSetIterator registered : this.iterators) {
                if (registered == iterator) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Sets the buffer size passed to the built iterator.
         *
         * @param bufferSize buffer size; not validated here
         * @return this builder
         */
        public Builder setBufferSizePerSplit(int bufferSize) {
            this.bufferSize = bufferSize;
            return this;
        }

        /**
         * Sets the single-device flag passed to the built iterator.
         *
         * @param reallyEnforce flag value
         * @return this builder
         */
        public Builder enforceSingleDevice(boolean reallyEnforce) {
            this.enforceSingleDevice = reallyEnforce;
            return this;
        }

        /**
         * Builds the iterator from the current builder state.
         *
         * @return a new {@link JointParallelDataSetIterator}
         * @throws IllegalArgumentException if no source iterator was added
         */
        public JointParallelDataSetIterator build() {
            return new JointParallelDataSetIterator(this.iterators, this.enforceSingleDevice, this.bufferSize, this.inequalityHandling);
        }
    }
}

