/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import java.util.List;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
import org.nd4j.linalg.dataset.callbacks.DefaultCallback;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AsyncDataSetIterator
implements DataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(AsyncDataSetIterator.class);
    protected DataSetIterator backedIterator;
    protected DataSet terminator = new DataSet();
    protected DataSet nextElement = null;
    protected BlockingQueue<DataSet> buffer;
    protected AsyncPrefetchThread thread;
    protected AtomicBoolean shouldWork = new AtomicBoolean(true);
    protected volatile RuntimeException throwable = null;
    protected boolean useWorkspace = true;
    protected int prefetchSize;
    protected String workspaceId;
    protected Integer deviceId;
    protected AtomicBoolean hasDepleted = new AtomicBoolean(false);
    protected DataSetCallback callback;

    protected AsyncDataSetIterator() {
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator) {
        this(baseIterator, 8);
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue) {
        this(iterator, queueSize, queue, true);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize));
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, new DefaultCallback(), deviceId);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, DataSetCallback callback) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, callback);
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace) {
        this(iterator, queueSize, queue, useWorkspace, new DefaultCallback());
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        if (queueSize < 2) {
            queueSize = 2;
        }
        this.deviceId = deviceId;
        this.callback = callback;
        this.useWorkspace = useWorkspace;
        this.buffer = queue;
        this.prefetchSize = queueSize;
        this.backedIterator = iterator;
        this.workspaceId = "ADSI_ITER-" + UUID.randomUUID().toString();
        if (iterator.resetSupported() && !iterator.hasNext()) {
            this.backedIterator.reset();
        }
        this.thread = new AsyncPrefetchThread(this.buffer, iterator, this.terminator, null, deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
    }

    @Override
    public DataSet next(int num) {
        throw new UnsupportedOperationException();
    }

    @Override
    public int inputColumns() {
        return this.backedIterator.inputColumns();
    }

    @Override
    public int totalOutcomes() {
        return this.backedIterator.totalOutcomes();
    }

    @Override
    public boolean resetSupported() {
        return this.backedIterator.resetSupported();
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    protected void externalCall() {
    }

    @Override
    public void reset() {
        this.buffer.clear();
        if (this.thread != null) {
            this.thread.interrupt();
        }
        try {
            if (this.thread != null) {
                this.thread.join();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        this.thread.shutdown();
        this.buffer.clear();
        this.backedIterator.reset();
        this.shouldWork.set(true);
        this.thread = new AsyncPrefetchThread(this.buffer, this.backedIterator, this.terminator, null, this.deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
        this.hasDepleted.set(false);
        this.nextElement = null;
    }

    public void shutdown() {
        this.buffer.clear();
        if (this.thread != null) {
            this.thread.interrupt();
        }
        try {
            if (this.thread != null) {
                this.thread.join();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        this.thread.shutdown();
        this.buffer.clear();
    }

    @Override
    public int batch() {
        return this.backedIterator.batch();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.backedIterator.setPreProcessor(preProcessor);
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.backedIterator.getPreProcessor();
    }

    @Override
    public List<String> getLabels() {
        return this.backedIterator.getLabels();
    }

    @Override
    public boolean hasNext() {
        if (this.throwable != null) {
            throw this.throwable;
        }
        try {
            if (this.hasDepleted.get()) {
                return false;
            }
            if (this.nextElement != null && this.nextElement != this.terminator) {
                return true;
            }
            if (this.nextElement == this.terminator) {
                return false;
            }
            this.nextElement = this.buffer.take();
            if (this.nextElement == this.terminator) {
                this.hasDepleted.set(true);
                return false;
            }
            return true;
        }
        catch (Exception e) {
            log.error("Premature end of loop!");
            throw new RuntimeException(e);
        }
    }

    @Override
    public DataSet next() {
        if (this.throwable != null) {
            throw this.throwable;
        }
        if (this.hasDepleted.get()) {
            return null;
        }
        DataSet temp = this.nextElement;
        this.nextElement = null;
        return temp;
    }

    @Override
    public void remove() {
    }

    protected class AsyncPrefetchThread
    extends Thread
    implements Runnable {
        private BlockingQueue<DataSet> queue;
        private DataSetIterator iterator;
        private DataSet terminator;
        private boolean isShutdown = false;
        private WorkspaceConfiguration configuration;
        private MemoryWorkspace workspace;
        private final int deviceId;

        protected AsyncPrefetchThread(@NonNull BlockingQueue<DataSet> queue, @NonNull DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) {
            this.configuration = WorkspaceConfiguration.builder().minSize(0xA00000L).overallocationLimit(AsyncDataSetIterator.this.prefetchSize + 2).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyLearning(LearningPolicy.FIRST_LOOP).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).build();
            if (queue == null) {
                throw new NullPointerException("queue is marked non-null but is null");
            }
            if (iterator == null) {
                throw new NullPointerException("iterator is marked non-null but is null");
            }
            if (terminator == null) {
                throw new NullPointerException("terminator is marked non-null but is null");
            }
            this.queue = queue;
            this.iterator = iterator;
            this.terminator = terminator;
            this.deviceId = deviceId;
            this.setDaemon(true);
            this.setName("ADSI prefetch thread");
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         * Unable to fully structure code
         */
        @Override
        public void run() {
            Nd4j.getAffinityManager().unsafeSetDevice(this.deviceId);
            AsyncDataSetIterator.this.externalCall();
            try {
                if (AsyncDataSetIterator.this.useWorkspace) {
                    this.workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(this.configuration, AsyncDataSetIterator.this.workspaceId);
                }
                while (this.iterator.hasNext() && AsyncDataSetIterator.this.shouldWork.get()) {
                    smth = null;
                    if (AsyncDataSetIterator.this.useWorkspace) {
                        ws = this.workspace.notifyScopeEntered();
                        var3_7 = null;
                        try {
                            smth = (DataSet)this.iterator.next();
                            if (AsyncDataSetIterator.this.callback == null) ** GOTO lbl33
                            AsyncDataSetIterator.this.callback.call((org.nd4j.linalg.dataset.api.DataSet)smth);
                        }
                        catch (Throwable var4_9) {
                            var3_7 = var4_9;
                            throw var4_9;
                        }
                        finally {
                            if (ws != null) {
                                if (var3_7 != null) {
                                    try {
                                        ws.close();
                                    }
                                    catch (Throwable var4_8) {
                                        var3_7.addSuppressed(var4_8);
                                    }
                                } else {
                                    ws.close();
                                }
                            }
                        }
                    } else {
                        smth = (DataSet)this.iterator.next();
                        if (AsyncDataSetIterator.this.callback != null) {
                            AsyncDataSetIterator.this.callback.call((org.nd4j.linalg.dataset.api.DataSet)smth);
                        }
                    }
lbl33:
                    // 5 sources

                    Nd4j.getExecutioner().commit();
                    if (smth == null) continue;
                    this.queue.put((DataSet)smth);
                }
                this.queue.put(this.terminator);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                AsyncDataSetIterator.this.shouldWork.set(false);
            }
            catch (RuntimeException e) {
                AsyncDataSetIterator.this.throwable = e;
                throw new RuntimeException(e);
            }
            catch (Exception e) {
                AsyncDataSetIterator.this.throwable = new RuntimeException(e);
                throw new RuntimeException(e);
            }
            finally {
                e = this;
                synchronized (e) {
                    this.isShutdown = true;
                    this.notifyAll();
                }
            }
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void shutdown() {
            AsyncPrefetchThread asyncPrefetchThread = this;
            synchronized (asyncPrefetchThread) {
                while (!this.isShutdown) {
                    try {
                        this.wait();
                    }
                    catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                }
            }
            if (this.workspace != null) {
                log.debug("Manually destroying ADSI workspace");
                this.workspace.destroyWorkspace(true);
            }
        }
    }
}

