/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism.inference.observers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Observable that collects inputs from several callers, merges consecutive shape-compatible
 * inputs into batches for a single forward pass, and splits the resulting outputs back so each
 * caller can retrieve the slice belonging to its own input.
 *
 * <p>Each call to {@link #addInput} records the index of the added input in a {@link ThreadLocal},
 * so {@link #getOutput()} returns the output for the input most recently added by the calling
 * thread (or the index set via {@link #setPosition(int)}).
 *
 * <p>Locking protocol as implemented here: {@link #isLocked()} acquires the read lock when it
 * returns {@code false}, and {@link #addInput} releases it. {@link #getInputBatches()} takes the
 * write lock, so it waits for such in-flight additions. It also sets a flag after which
 * {@link #isLocked()} always reports {@code true}.
 */
public class BatchedInferenceObservable
extends BasicInferenceObservable
implements InferenceObservable {
    private static final Logger log = LoggerFactory.getLogger(BatchedInferenceObservable.class);
    private final List<INDArray[]> inputs = new ArrayList<INDArray[]>();
    private final List<INDArray[]> inputMasks = new ArrayList<INDArray[]>();
    private final List<INDArray[]> outputs = new ArrayList<INDArray[]>();
    private final AtomicInteger counter = new AtomicInteger(0);
    /** Index (into {@link #inputs}/{@link #outputs}) of the input added by the current thread. */
    private final ThreadLocal<Integer> position = new ThreadLocal<Integer>();
    /** For each merged batch, the inclusive {@code [first, last]} range of input indices it covers. */
    private final List<int[]> outputBatchInputArrays = new ArrayList<int[]>();
    private final Object locker = new Object();
    private final ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock();
    /** Set once batch extraction has started; never cleared. */
    private final AtomicBoolean isLocked = new AtomicBoolean(false);
    /** Set once any thread has acquired the read lock through {@link #isLocked()}. */
    private final AtomicBoolean isReadLocked = new AtomicBoolean(false);

    public BatchedInferenceObservable() {
        super(new INDArray[0]);
    }

    /**
     * Appends one input (with its optional masks) and records its index for the calling thread.
     *
     * <p>If the calling thread holds a read lock obtained via {@link #isLocked()}, that lock is
     * released here.
     *
     * @param input      the feature arrays for one request
     * @param inputMasks the mask arrays for the request; may be {@code null}
     */
    @Override
    public void addInput(INDArray[] input, INDArray[] inputMasks) {
        synchronized (this.locker) {
            this.inputs.add(input);
            this.inputMasks.add(inputMasks);
            this.position.set(this.counter.getAndIncrement());
            // isReadLocked is shared by all threads and never reset, so also verify that THIS
            // thread holds a read lock: unlocking an unheld read lock throws
            // IllegalMonitorStateException.
            if (this.isReadLocked.get() && this.realLocker.getReadHoldCount() > 0) {
                this.realLocker.readLock().unlock();
            }
        }
    }

    /**
     * Locks this observable against further additions and returns the inputs grouped into
     * batches. Each run of consecutive inputs with identical feature shapes is merged into one
     * batch.
     *
     * @return one (features, masks) pair per batch
     * @throws IllegalStateException if no inputs have been added
     */
    @Override
    public List<Pair<INDArray[], INDArray[]>> getInputBatches() {
        this.realLocker.writeLock().lock();
        // try/finally so a failure while merging never leaves the write lock held.
        try {
            this.isLocked.set(true);
            this.outputBatchInputArrays.clear();
            if (this.inputs.isEmpty()) {
                throw new IllegalStateException("No inputs have been added to this observable");
            }
            if (this.counter.get() > 1) {
                return this.mergeCompatibleInputs();
            }
            this.outputBatchInputArrays.add(new int[]{0, 0});
            return Collections.singletonList(
                    new Pair<INDArray[], INDArray[]>(this.inputs.get(0), this.inputMasks.get(0)));
        } finally {
            this.realLocker.writeLock().unlock();
        }
    }

    /**
     * Merges runs of consecutive, shape-compatible inputs and records each run's input range in
     * {@link #outputBatchInputArrays}. Caller must hold the write lock.
     */
    private List<Pair<INDArray[], INDArray[]>> mergeCompatibleInputs() {
        List<Pair<INDArray[], INDArray[]>> out = new ArrayList<Pair<INDArray[], INDArray[]>>();
        int pos = 0;
        while (pos < this.inputs.size()) {
            // Extend the run while the next input has the same shapes as the run's first input.
            int lastPossible = pos;
            while (lastPossible + 1 < this.inputs.size()
                    && BatchedInferenceObservable.canBatch(this.inputs.get(pos), this.inputs.get(lastPossible + 1))) {
                ++lastPossible;
            }
            int countToMerge = lastPossible - pos + 1;
            INDArray[][] featuresToMerge = new INDArray[countToMerge][];
            // Mask rows stay null for inputs without masks; the whole array stays null if none have masks.
            INDArray[][] fMasksToMerge = null;
            for (int k = 0; k < countToMerge; ++k) {
                featuresToMerge[k] = this.inputs.get(pos + k);
                INDArray[] mask = this.inputMasks.get(pos + k);
                if (mask != null) {
                    if (fMasksToMerge == null) {
                        fMasksToMerge = new INDArray[countToMerge][];
                    }
                    fMasksToMerge[k] = mask;
                }
            }
            Pair<INDArray[], INDArray[]> merged = DataSetUtil.mergeFeatures(featuresToMerge, fMasksToMerge);
            out.add(merged);
            this.outputBatchInputArrays.add(new int[]{pos, lastPossible});
            pos = lastPossible + 1;
        }
        return out;
    }

    /**
     * Returns {@code true} if both inputs have the same number of arrays and every corresponding
     * pair of arrays has an identical shape.
     */
    private static boolean canBatch(INDArray[] first, INDArray[] candidate) {
        if (first.length != candidate.length) {
            return false;
        }
        for (int i = 0; i < first.length; ++i) {
            if (!Arrays.equals(first[i].shape(), candidate[i].shape())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Splits each batch's outputs back into per-input outputs and notifies observers.
     *
     * @param output one array of network outputs per batch, in the order returned by
     *               {@link #getInputBatches()}
     */
    @Override
    public void setOutputBatches(List<INDArray[]> output) {
        for (int outBatchNum = 0; outBatchNum < output.size(); ++outBatchNum) {
            INDArray[] currBatchOutputs = output.get(outBatchNum);
            int[] inputBatchIdxs = this.outputBatchInputArrays.get(outBatchNum);
            int firstInput = inputBatchIdxs[0];
            int inputBatchCount = inputBatchIdxs[1] - firstInput + 1;
            // Batches cover inputs contiguously from index 0, so appending one slot per input keeps
            // outputs.get(i) aligned with input i (and therefore with each thread's position).
            for (int i = 0; i < inputBatchCount; ++i) {
                this.outputs.add(new INDArray[currBatchOutputs.length]);
            }
            for (int outputNumber = 0; outputNumber < currBatchOutputs.length; ++outputNumber) {
                INDArray[] split = this.splitExamples(currBatchOutputs[outputNumber], inputBatchIdxs[0], inputBatchIdxs[1]);
                for (int inputInBatch = 0; inputInBatch < inputBatchCount; ++inputInBatch) {
                    this.outputs.get(firstInput + inputInBatch)[outputNumber] = split[inputInBatch];
                }
            }
        }
        this.setChanged();
        this.notifyObservers();
    }

    /**
     * Slices a merged network output along dimension 0 into one view per input. Each slice's
     * length is the size along dimension 0 of that input's first feature array.
     *
     * @param netOutput            merged output for the inputs in the given range
     * @param firstInputComponent  index of the first input in the batch (inclusive)
     * @param lastInputComponent   index of the last input in the batch (inclusive)
     * @return one slice per input, in input order
     */
    private INDArray[] splitExamples(INDArray netOutput, int firstInputComponent, int lastInputComponent) {
        int numSplits = lastInputComponent - firstInputComponent + 1;
        if (numSplits == 1) {
            return new INDArray[]{netOutput};
        }
        INDArray[] out = new INDArray[numSplits];
        INDArrayIndex[] indices = new INDArrayIndex[netOutput.rank()];
        for (int i = 1; i < indices.length; ++i) {
            indices[i] = NDArrayIndex.all();
        }
        long examplesSoFar = 0L;
        for (int inNum = 0; inNum < numSplits; ++inNum) {
            long inSizeEx = this.inputs.get(firstInputComponent + inNum)[0].size(0);
            indices[0] = NDArrayIndex.interval(examplesSoFar, examplesSoFar + inSizeEx);
            out[inNum] = netOutput.get(indices);
            examplesSoFar += inSizeEx;
        }
        return out;
    }

    /** Returns the internal (mutable) list of per-input outputs. */
    protected List<INDArray[]> getOutputs() {
        return this.outputs;
    }

    protected void setCounter(int value) {
        this.counter.set(value);
    }

    /** Sets the input index whose output {@link #getOutput()} returns for the calling thread. */
    public void setPosition(int pos) {
        this.position.set(pos);
    }

    /** Returns the number of inputs recorded so far (or the value last set via setCounter). */
    public int getCounter() {
        return this.counter.get();
    }

    /**
     * Reports whether this observable can no longer accept inputs.
     *
     * <p>When this returns {@code false}, the calling thread has acquired the read lock. That lock
     * is released by a subsequent {@link #addInput} from the same thread, so a {@code false}
     * result should be followed by {@code addInput}.
     *
     * @return {@code true} if the read lock could not be acquired (for example, the write lock is
     *         held by another thread) or batch extraction has already started
     */
    public boolean isLocked() {
        if (!this.realLocker.readLock().tryLock()) {
            return true;
        }
        if (this.isLocked.get()) {
            // Extraction already started: give back the read lock we just took instead of leaking it.
            this.realLocker.readLock().unlock();
            return true;
        }
        this.isReadLocked.set(true);
        return false;
    }

    /**
     * Returns the outputs for the input recorded for the calling thread.
     *
     * @throws IllegalStateException if the calling thread never added an input or set a position
     */
    @Override
    public INDArray[] getOutput() {
        this.checkOutputException();
        Integer pos = this.position.get();
        if (pos == null) {
            throw new IllegalStateException(
                    "No input position recorded for the current thread; call addInput or setPosition first");
        }
        return this.outputs.get(pos);
    }
}

