/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator;

import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DataSetIterator} wrapper that prefetches {@link DataSet}s from a base iterator on a
 * background daemon thread and hands them to the consumer through a bounded blocking queue.
 *
 * <p>{@link #hasNext()}, {@link #next()} and {@link #reset()} are {@code synchronized} on this
 * instance. A {@link RuntimeException} raised by the base iterator on the background thread is
 * captured and rethrown to the consumer from {@code hasNext()} / {@code next()}.
 */
public class AsyncDataSetIterator implements DataSetIterator {
    private final DataSetIterator baseIterator;
    /** Prefetched DataSets waiting to be consumed; capacity is fixed at construction. */
    private final BlockingQueue<DataSet> blockingQueue;
    /** The current producer thread; set to {@code null} by {@link #shutdown()}. */
    private Thread thread;
    /** The current producer; replaced with a fresh instance on every {@link #reset()}. */
    private IteratorRunnable runnable;
    protected static final Logger logger = LoggerFactory.getLogger(AsyncDataSetIterator.class);

    /**
     * Wraps {@code baseIterator} with a prefetch queue holding up to 8 DataSets.
     *
     * @param baseIterator the iterator to prefetch from
     */
    public AsyncDataSetIterator(DataSetIterator baseIterator) {
        this(baseIterator, 8);
    }

    /**
     * Wraps {@code baseIterator} and immediately starts a daemon producer thread, attached to the
     * Nd4j device of the calling thread, that fills the prefetch queue.
     *
     * @param baseIterator the iterator to prefetch from; it is reset first if it supports reset
     * @param queueSize    maximum number of prefetched DataSets; a value of 1 is raised to 2
     * @throws IllegalArgumentException if {@code queueSize <= 0}
     */
    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize) {
        if (queueSize <= 0) {
            throw new IllegalArgumentException("Queue size must be > 0");
        }
        if (queueSize < 2) {
            queueSize = 2;
        }
        this.baseIterator = baseIterator;
        // Rewind the base iterator (when possible) before prefetching starts.
        if (this.baseIterator.resetSupported()) {
            this.baseIterator.reset();
        }
        this.blockingQueue = new LinkedBlockingDeque<DataSet>(queueSize);
        startProducer();
    }

    /**
     * Creates a fresh producer for the base iterator's current position, attaches it to the
     * calling thread's Nd4j device, and starts it as a daemon thread.
     */
    private void startProducer() {
        this.runnable = new IteratorRunnable(this.baseIterator.hasNext());
        this.thread = this.runnable;
        Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        Nd4j.getAffinityManager().attachThreadToDevice(this.thread, deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
    }

    /**
     * Not supported by this iterator.
     *
     * @throws UnsupportedOperationException always
     */
    public DataSet next(int num) {
        throw new UnsupportedOperationException("Next(int) not supported for AsyncDataSetIterator");
    }

    /** Delegates to the base iterator. */
    public int totalExamples() {
        return this.baseIterator.totalExamples();
    }

    /** Delegates to the base iterator. */
    public int inputColumns() {
        return this.baseIterator.inputColumns();
    }

    /** Delegates to the base iterator. */
    public int totalOutcomes() {
        return this.baseIterator.totalOutcomes();
    }

    /** Reset is supported exactly when the base iterator supports it. */
    public boolean resetSupported() {
        return this.baseIterator.resetSupported();
    }

    /**
     * Always returns {@code false}. This iterator already prefetches on its own thread (presumably
     * reported so callers do not wrap it in another async iterator).
     */
    public boolean asyncSupported() {
        return false;
    }

    /**
     * Stops the current producer, discards all prefetched DataSets, resets the base iterator and
     * starts a new producer.
     *
     * <p>Waits up to 5 seconds for the old producer to exit. If it has not exited by then, a
     * warning is logged and the reset proceeds anyway.
     *
     * @throws UnsupportedOperationException if the base iterator does not support reset
     */
    public synchronized void reset() {
        if (!this.resetSupported()) {
            throw new UnsupportedOperationException("Cannot reset Async iterator wrapping iterator that does not support reset");
        }
        // The producer may be blocked in baseIterator.next() or blockingQueue.put(): ask it to stop,
        // and interrupt it so a blocked put() returns promptly.
        this.runnable.killRunnable = true;
        Thread worker = this.thread;
        // worker is null after shutdown(), which has already interrupted the producer.
        if (this.runnable.isAlive.get() && worker != null) {
            worker.interrupt();
        }
        try {
            if (!this.runnable.runCompletedSemaphore.tryAcquire(5L, TimeUnit.SECONDS)) {
                logger.warn("AsyncDataSetIterator producer thread did not stop within 5 seconds; resetting base iterator anyway");
            }
        } catch (InterruptedException e) {
            // Preserve the caller's interrupt status, but still complete the reset.
            Thread.currentThread().interrupt();
        }
        this.blockingQueue.clear();
        this.baseIterator.reset();
        startProducer();
    }

    /** Delegates to the base iterator. */
    public int batch() {
        return this.baseIterator.batch();
    }

    /**
     * Delegates to the base iterator. Because the producer reads ahead of the consumer, this may be
     * ahead of the DataSets actually returned by {@link #next()}.
     */
    public int cursor() {
        return this.baseIterator.cursor();
    }

    /** Delegates to the base iterator. */
    public int numExamples() {
        return this.baseIterator.numExamples();
    }

    /** Sets the pre-processor on the base iterator. */
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.baseIterator.setPreProcessor(preProcessor);
    }

    /** Returns the base iterator's pre-processor. */
    public DataSetPreProcessor getPreProcessor() {
        return this.baseIterator.getPreProcessor();
    }

    /** Delegates to the base iterator. */
    public List<String> getLabels() {
        return this.baseIterator.getLabels();
    }

    /**
     * Returns {@code true} if a prefetched DataSet is queued or the producer can still supply one.
     * While the producer is alive this may spin until the answer is known.
     *
     * @throws RuntimeException the producer's captured exception, if the producer has exited with
     *                          a failure and was not deliberately stopped
     */
    public synchronized boolean hasNext() {
        if (!this.blockingQueue.isEmpty()) {
            return true;
        }
        if (this.runnable.isAlive.get()) {
            return this.runnable.hasLatch();
        }
        if (!this.runnable.killRunnable && this.runnable.exception != null) {
            throw this.runnable.exception;
        }
        return this.runnable.hasLatch();
    }

    /**
     * Returns the next prefetched DataSet, waiting for the producer if the queue is empty.
     *
     * @throws NoSuchElementException          if {@link #hasNext()} is {@code false}
     * @throws RuntimeException                the producer's captured exception, or a wrapper of an
     *                                         {@link InterruptedException} raised while waiting
     *                                         (the interrupt status is restored)
     * @throws ConcurrentModificationException if the producer was told to stop while waiting
     * @throws IllegalStateException           if the producer exited without supplying an element
     */
    public synchronized DataSet next() {
        if (!this.hasNext()) {
            throw new NoSuchElementException();
        }
        if (this.runnable.exception != null) {
            throw this.runnable.exception;
        }
        if (!this.blockingQueue.isEmpty()) {
            this.runnable.feeder.decrementAndGet();
            return this.blockingQueue.poll();
        }
        try {
            while (this.runnable.exception == null) {
                DataSet ds = this.blockingQueue.poll(2L, TimeUnit.SECONDS);
                if (ds != null) {
                    this.runnable.feeder.decrementAndGet();
                    return ds;
                }
                if (this.runnable.killRunnable) {
                    throw new ConcurrentModificationException("Reset while next() is waiting for element?");
                }
                // Keep waiting while the producer may still put something on the queue.
                if (this.runnable.isAlive.get() || !this.blockingQueue.isEmpty()) {
                    continue;
                }
                if (this.runnable.exception != null) {
                    throw new RuntimeException("Exception thrown in base iterator", this.runnable.exception);
                }
                throw new IllegalStateException("Unexpected state occurred for AsyncDataSetIterator: runnable died or no data available");
            }
            throw this.runnable.exception;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Asks the producer thread to stop and interrupts it, if it is still running. After this call
     * the thread reference is cleared; {@link #reset()} starts a new producer.
     */
    public void shutdown() {
        if (this.thread != null && this.thread.isAlive()) {
            this.runnable.killRunnable = true;
            this.thread.interrupt();
            this.thread = null;
        }
    }

    /** No-op: removal is silently ignored by this iterator. */
    public void remove() {
    }

    /**
     * Background producer. It takes DataSets from the base iterator and puts them on the prefetch
     * queue. It stops when the base iterator is exhausted, when {@code killRunnable} is set, or
     * when an exception occurs.
     */
    private class IteratorRunnable extends Thread {
        /** Set by the consumer side ({@code reset()} / {@code shutdown()}) to ask this producer to stop. */
        private volatile boolean killRunnable = false;
        /** Initially whether the base iterator had elements; cleared when {@code run()} exits. */
        private final AtomicBoolean isAlive = new AtomicBoolean(true);
        /** Failure raised while producing, rethrown to the consumer. */
        private volatile RuntimeException exception;
        /** Released once when {@code run()} exits, so {@code reset()} can wait for this producer. */
        private final Semaphore runCompletedSemaphore = new Semaphore(0);
        /** Write-held around {@code baseIterator.next()}; read-held by {@link #hasLatch()}. */
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        /** Incremented before each {@code baseIterator.next()} call, decremented when the consumer takes an element. */
        private final AtomicLong feeder = new AtomicLong(0L);

        /**
         * @param hasNext whether the base iterator currently has elements; seeds {@code isAlive}
         */
        public IteratorRunnable(boolean hasNext) {
            this.isAlive.set(hasNext);
            this.setName("AsyncIterator thread");
            this.setDaemon(true);
        }

        /**
         * Reports whether more DataSets may still become available.
         *
         * <p>Returns {@code true} immediately if an element is queued or being fetched. Otherwise,
         * while the producer is alive, this busy-waits until either an element appears or the
         * producer exits.
         */
        public boolean hasLatch() {
            if (this.feeder.get() > 0L || !AsyncDataSetIterator.this.blockingQueue.isEmpty()) {
                return true;
            }
            // Lock before entering try so the finally never unlocks a lock we do not hold.
            this.lock.readLock().lock();
            try {
                boolean result = AsyncDataSetIterator.this.baseIterator.hasNext()
                        || this.feeder.get() != 0L
                        || !AsyncDataSetIterator.this.blockingQueue.isEmpty();
                if (!this.isAlive.get()) {
                    return result;
                }
                while (this.isAlive.get()) {
                    result = this.feeder.get() != 0L
                            || !AsyncDataSetIterator.this.blockingQueue.isEmpty()
                            || AsyncDataSetIterator.this.baseIterator.hasNext();
                    if (result) {
                        return true;
                    }
                }
                return result;
            } finally {
                this.lock.readLock().unlock();
            }
        }

        @Override
        public void run() {
            try {
                while (!this.killRunnable && AsyncDataSetIterator.this.baseIterator.hasNext()) {
                    this.feeder.incrementAndGet();
                    DataSet ds;
                    this.lock.writeLock().lock();
                    try {
                        ds = (DataSet) AsyncDataSetIterator.this.baseIterator.next();
                        // Presumably flushes ops queued by a GridExecutioner so ds is fully
                        // materialised before it is handed to the consumer thread.
                        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
                            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
                        }
                    } finally {
                        // Released on every exit path (including Errors), so hasLatch() cannot
                        // block forever on a lock held by a dead producer.
                        this.lock.writeLock().unlock();
                    }
                    AsyncDataSetIterator.this.blockingQueue.put(ds);
                }
            } catch (InterruptedException e) {
                // An interrupt requested via killRunnable is a normal stop; anything else is a failure.
                if (!this.killRunnable) {
                    this.exception = new RuntimeException("Runnable interrupted unexpectedly", e);
                }
            } catch (RuntimeException e) {
                this.exception = e;
            } finally {
                this.isAlive.set(false);
                this.runCompletedSemaphore.release();
            }
        }
    }
}

