/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.AbstractLSTM;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.OldMulOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LSTMHelpers {
    private static final Logger log = LoggerFactory.getLogger(LSTMHelpers.class);

    /** Not meant to be instantiated: every member of this class visible here is static. */
    private LSTMHelpers() {}

    /**
     * Runs the LSTM forward pass over every time step of {@code input} and returns the resulting
     * activations and cell states, optionally keeping the per-time-step intermediates needed by
     * {@link #backpropGradientHelper}.
     *
     * <p>Column layout of the per-step {@code ifogActivations} buffer, as sliced in this method
     * (h = {@code hiddenLayerSize}): {@code [0,h)} input block (layer activation function applied),
     * {@code [h,2h)} forget gate, {@code [2h,3h)} output gate, {@code [3h,4h)} input modulation gate
     * (gate activation function applied to the last three).
     *
     * <p>If {@code helper} is non-null and its {@code activate(...)} call returns a non-null result,
     * that result is returned as-is and the built-in loop below is skipped.
     *
     * @param layer layer whose {@code layerConf()} supplies the cell/input activation function
     * @param conf only forwarded to {@code helper}
     * @param gateActivationFn activation applied to the forget, output and input modulation gates
     * @param input rank &lt; 3 is treated as a single time step; otherwise {@code size(2)} is the
     *        number of time steps. {@code size(0)} is used as the mini-batch size and
     *        {@code size(1)} must equal {@code originalInputWeights.size(0)}
     * @param recurrentWeights {@code size(0)} is taken as the hidden layer size; columns
     *        {@code [0,4h)} are the recurrent weights, and with peepholes columns {@code 4h},
     *        {@code 4h+1}, {@code 4h+2} are read as the wFF, wOO and wGG vectors
     * @param originalInputWeights input weights multiplied with each time-step slice of the input
     * @param biases row vector added to the pre-activations at every time step
     * @param training forwarded to the activation functions and to {@code helper}
     * @param originalPrevOutputActivations previous output activations, or {@code null} to start
     *        from zeros
     * @param originalPrevMemCellState previous memory cell state, or {@code null} to start from
     *        zeros; a non-null array is copied before use
     * @param forBackprop when {@code true}, per-time-step arrays are stored in the result
     * @param forwards when {@code false}, time steps are processed from last to first
     * @param inputWeightKey only forwarded to {@code helper}
     * @param maskArray if non-null, column {@code time} is multiplied into the hidden activations
     *        and the cell state of each time step
     * @param hasPeepholeConnections whether the peephole terms are added to the gate pre-activations
     * @param helper optional alternative implementation, may be {@code null}
     * @param cacheMode when not {@link CacheMode#NONE}, selected allocations are bracketed by
     *        {@code notifyScopeBorrowed()} / {@code notifyScopeLeft()} on the "LOOP_CACHE" workspace
     * @return the forward-pass result. Note that {@code fwdPassOutput} is only populated when
     *         {@code forBackprop} is {@code false} or {@code cacheMode} is not {@code NONE}
     * @throws IllegalArgumentException if {@code input} is {@code null} or has length 0
     * @throws DL4JInvalidInputException if the input feature size does not match the input weights,
     *         or the stored previous activations have a different number of examples than the input
     */
    public static FwdPassReturn activateHelper(BaseLayer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, INDArray recurrentWeights, INDArray originalInputWeights, INDArray biases, boolean training, INDArray originalPrevOutputActivations, INDArray originalPrevMemCellState, boolean forBackprop, boolean forwards, String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LSTMHelper helper, CacheMode cacheMode) {
        FwdPassReturn ret;
        if (input == null || input.length() == 0) {
            throw new IllegalArgumentException("Invalid input: not set or 0 length");
        }
        INDArray inputWeights = originalInputWeights;
        INDArray prevOutputActivations = originalPrevOutputActivations;
        // Input of rank < 3 is handled as a single time step.
        boolean is2dInput = input.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : input.size(2);
        int hiddenLayerSize = recurrentWeights.size(0);
        int miniBatchSize = input.size(0);
        // Work on a private 'f'-ordered cell state so the caller's array is never modified.
        INDArray prevMemCellState = originalPrevMemCellState == null ? Nd4j.create((int[])new int[]{miniBatchSize, hiddenLayerSize}, (char)'f') : originalPrevMemCellState.dup('f');
        // First 4*h columns: recurrent weights for the four gate blocks, copied in 'f' order.
        INDArray recurrentWeightsIFOG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)(4 * hiddenLayerSize))}).dup('f');
        INDArray wFFTranspose = null;
        INDArray wOOTranspose = null;
        INDArray wGGTranspose = null;
        if (hasPeepholeConnections) {
            // Peephole weights are the three single columns stored after the 4*h recurrent columns.
            wFFTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(4 * hiddenLayerSize), (int)(4 * hiddenLayerSize + 1))}).transpose();
            wOOTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(4 * hiddenLayerSize + 1), (int)(4 * hiddenLayerSize + 2))}).transpose();
            wGGTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(4 * hiddenLayerSize + 2), (int)(4 * hiddenLayerSize + 3))}).transpose();
            // Only converted when they will be reused (several time steps, or a backprop pass).
            if (timeSeriesLength > 1 || forBackprop) {
                wFFTranspose = Shape.toMmulCompatible((INDArray)wFFTranspose);
                wOOTranspose = Shape.toMmulCompatible((INDArray)wOOTranspose);
                wGGTranspose = Shape.toMmulCompatible((INDArray)wGGTranspose);
            }
        }
        // With sigmoid gates the pre-activations (fz/oz/gz) are not stored for backprop.
        boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
        IActivation afn = ((org.deeplearning4j.nn.conf.layers.BaseLayer)layer.layerConf()).getActivationFn();
        INDArray outputActivations = null;
        FwdPassReturn toReturn = new FwdPassReturn();
        if (forBackprop) {
            // One slot per time step for everything the backward pass reads.
            toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
            toReturn.memCellState = new INDArray[timeSeriesLength];
            toReturn.memCellActivations = new INDArray[timeSeriesLength];
            toReturn.iz = new INDArray[timeSeriesLength];
            toReturn.ia = new INDArray[timeSeriesLength];
            toReturn.fa = new INDArray[timeSeriesLength];
            toReturn.oa = new INDArray[timeSeriesLength];
            toReturn.ga = new INDArray[timeSeriesLength];
            if (!sigmoidGates) {
                toReturn.fz = new INDArray[timeSeriesLength];
                toReturn.oz = new INDArray[timeSeriesLength];
                toReturn.gz = new INDArray[timeSeriesLength];
            }
            // The combined 3d output is only allocated for backprop when caching is enabled;
            // with CacheMode.NONE outputActivations stays null and is not written in the loop.
            if (cacheMode != CacheMode.NONE) {
                try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();){
                    toReturn.fwdPassOutput = outputActivations = Nd4j.create((int[])new int[]{miniBatchSize, hiddenLayerSize, timeSeriesLength}, (char)'f');
                }
            }
        } else {
            toReturn.fwdPassOutput = outputActivations = Nd4j.create((int[])new int[]{miniBatchSize, hiddenLayerSize, timeSeriesLength}, (char)'f');
        }
        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
        // Shape validation: input feature count vs. input weight rows.
        if (input.size(1) != inputWeights.size(0)) {
            throw new DL4JInvalidInputException("Received input with size(1) = " + input.size(1) + " (input array shape = " + Arrays.toString(input.shape()) + "); input.size(1) must match layer nIn size (nIn = " + inputWeights.size(0) + ")");
        }
        // Shape validation: stored state must have the same number of examples as the input.
        if (prevOutputActivations != null && prevOutputActivations.size(0) != input.size(0)) {
            throw new DL4JInvalidInputException("Previous activations (stored state) number of examples = " + prevOutputActivations.size(0) + " but input array number of examples = " + input.size(0) + ". Possible cause: using rnnTimeStep() without calling rnnClearPreviousState() between different sequences?");
        }
        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros((int[])new int[]{miniBatchSize, hiddenLayerSize});
        }
        // Give the optional helper the first chance; a null result falls through to the loop below.
        if (helper != null && (ret = helper.activate(layer, conf, gateActivationFn, input, recurrentWeights, inputWeights, biases, training, prevOutputActivations, prevMemCellState, forBackprop, forwards, inputWeightKey, maskArray, hasPeepholeConnections)) != null) {
            return ret;
        }
        for (int iTimeIndex = 0; iTimeIndex < timeSeriesLength; ++iTimeIndex) {
            INDArray currHiddenUnitActivations;
            INDArray inputModMulInput;
            INDArray currentMemoryCellState;
            // 'time' is the index into the data; it runs backwards when !forwards.
            int time = iTimeIndex;
            if (!forwards) {
                time = timeSeriesLength - iTimeIndex - 1;
            }
            INDArray miniBatchData = is2dInput ? input : input.tensorAlongDimension(time, new int[]{1, 0});
            miniBatchData = Shape.toMmulCompatible((INDArray)miniBatchData);
            if (cacheMode != CacheMode.NONE) {
                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
            }
            // Pre-activations for all four blocks: x_t * W ...
            INDArray ifogActivations = miniBatchData.mmul(inputWeights);
            if (cacheMode != CacheMode.NONE) {
                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
            }
            // ... plus h_{t-1} * RW accumulated into the same buffer (alpha = beta = 1) ...
            Nd4j.gemm((INDArray)prevOutputActivations, (INDArray)recurrentWeightsIFOG, (INDArray)ifogActivations, (boolean)false, (boolean)false, (double)1.0, (double)1.0);
            // ... plus the bias row.
            ifogActivations.addiRowVector(biases);
            // Block [0,h): input block.
            INDArray inputActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
            if (forBackprop) {
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
                }
                // Keep a copy of the pre-activation before it is overwritten by the activation.
                toReturn.iz[time] = inputActivations.dup('f');
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
                }
            }
            // Return value is discarded: the activation is relied on to modify inputActivations in place.
            ((org.deeplearning4j.nn.conf.layers.BaseLayer)layer.layerConf()).getActivationFn().getActivation(inputActivations, training);
            if (forBackprop) {
                toReturn.ia[time] = inputActivations;
            }
            // Block [h,2h): forget gate.
            INDArray forgetGateActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)hiddenLayerSize, (int)(2 * hiddenLayerSize))});
            if (hasPeepholeConnections) {
                // Peephole term: previous cell state scaled by wFF, added to the pre-activation.
                INDArray pmcellWFF = prevMemCellState.dup('f').muliRowVector(wFFTranspose);
                l1BLAS.axpy(pmcellWFF.length(), 1.0, pmcellWFF, forgetGateActivations);
            }
            if (forBackprop && !sigmoidGates) {
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
                }
                toReturn.fz[time] = forgetGateActivations.dup('f');
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
                }
            }
            gateActivationFn.getActivation(forgetGateActivations, training);
            if (forBackprop) {
                toReturn.fa[time] = forgetGateActivations;
            }
            // Block [3h,4h): input modulation gate.
            INDArray inputModGateActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
            if (hasPeepholeConnections) {
                // Peephole term: previous cell state scaled by wGG.
                INDArray pmcellWGG = prevMemCellState.dup('f').muliRowVector(wGGTranspose);
                l1BLAS.axpy(pmcellWGG.length(), 1.0, pmcellWGG, inputModGateActivations);
            }
            if (forBackprop && !sigmoidGates) {
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
                }
                toReturn.gz[time] = inputModGateActivations.dup('f');
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
                }
            }
            gateActivationFn.getActivation(inputModGateActivations, training);
            if (forBackprop) {
                toReturn.ga[time] = inputModGateActivations;
            }
            if (forBackprop) {
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
                }
                // Backprop path works on copies so the stored fa/ga arrays keep their values.
                currentMemoryCellState = prevMemCellState.dup('f').muli(forgetGateActivations);
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
                }
                inputModMulInput = inputModGateActivations.dup('f').muli(inputActivations);
            } else {
                // Inference path reuses the gate buffers in place; they are not needed afterwards.
                currentMemoryCellState = forgetGateActivations.muli(prevMemCellState);
                inputModMulInput = inputModGateActivations.muli(inputActivations);
            }
            // c_t = f_t * c_{t-1} + g_t * i_t
            l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState);
            // Block [2h,3h): output gate.
            INDArray outputGateActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))});
            if (hasPeepholeConnections) {
                // Output gate peephole uses the *current* cell state, scaled by wOO.
                INDArray pmcellWOO = currentMemoryCellState.dup('f').muliRowVector(wOOTranspose);
                l1BLAS.axpy(pmcellWOO.length(), 1.0, pmcellWOO, outputGateActivations);
            }
            if (forBackprop && !sigmoidGates) {
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
                }
                toReturn.oz[time] = outputGateActivations.dup('f');
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
                }
            }
            gateActivationFn.getActivation(outputGateActivations, training);
            if (forBackprop) {
                toReturn.oa[time] = outputGateActivations;
            }
            if (cacheMode != CacheMode.NONE) {
                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
            }
            // Layer activation applied to a copy of the cell state; the cell state itself is kept.
            INDArray currMemoryCellActivation = afn.getActivation(currentMemoryCellState.dup('f'), training);
            if (cacheMode != CacheMode.NONE) {
                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
            }
            // h_t = afn(c_t) * o_t; copied first on the backprop path because afn(c_t) is stored too.
            if (forBackprop) {
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeBorrowed();
                }
                currHiddenUnitActivations = currMemoryCellActivation.dup('f').muli(outputGateActivations);
                if (cacheMode != CacheMode.NONE) {
                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("LOOP_CACHE").notifyScopeLeft();
                }
            } else {
                currHiddenUnitActivations = currMemoryCellActivation.muli(outputGateActivations);
            }
            if (maskArray != null) {
                // Apply this time step's mask column to both the output and the carried cell state.
                INDArray timeStepMaskColumn = maskArray.getColumn(time);
                currHiddenUnitActivations.muliColumnVector(timeStepMaskColumn);
                currentMemoryCellState.muliColumnVector(timeStepMaskColumn);
            }
            if (forBackprop) {
                toReturn.fwdPassOutputAsArrays[time] = currHiddenUnitActivations;
                toReturn.memCellState[time] = currentMemoryCellState;
                toReturn.memCellActivations[time] = currMemoryCellActivation;
                // outputActivations was only allocated for backprop when caching is enabled.
                if (cacheMode != CacheMode.NONE) {
                    outputActivations.tensorAlongDimension(time, new int[]{1, 0}).assign(currHiddenUnitActivations);
                }
            } else {
                outputActivations.tensorAlongDimension(time, new int[]{1, 0}).assign(currHiddenUnitActivations);
            }
            // Carry state to the next iteration and remember the most recent values.
            prevOutputActivations = currHiddenUnitActivations;
            prevMemCellState = currentMemoryCellState;
            toReturn.lastAct = currHiddenUnitActivations;
            toReturn.lastMemCell = currentMemoryCellState;
        }
        // The caller's original (possibly null) initial state is recorded, not the internal copies.
        toReturn.prevAct = originalPrevOutputActivations;
        toReturn.prevMemCell = originalPrevMemCellState;
        return toReturn;
    }

    /**
     * Back-propagates {@code epsilon} through time for an LSTM layer, accumulating the input-weight,
     * recurrent-weight and bias gradients into the arrays found in {@code gradientViews} and
     * producing the epsilon array for the layer below.
     *
     * <p>The three gradient arrays looked up in {@code gradientViews} are zeroed first and then
     * accumulated into; the same array instances are placed in the returned {@link Gradient}.
     * If {@code helper} is non-null and its {@code backpropGradient(...)} call returns a non-null
     * result, that result is returned as-is (the gradient arrays have already been zeroed by then).
     *
     * <p>Side effect: some arrays held by {@code fwdPass} ({@code oa[time]} and
     * {@code fa[time + inext]}) are multiplied in place during the loop, so {@code fwdPass} should
     * not be assumed to be unchanged afterwards.
     *
     * @param conf configuration whose layer supplies the cell/input activation function
     * @param gateActivationFn gate activation function used in the forward pass
     * @param input layer input; rank &lt; 3 epsilon means the input is used as a single time step
     * @param recurrentWeights {@code size(0)} is taken as the hidden layer size (h); columns
     *        {@code [0,4h)} are the recurrent weights, columns {@code 4h}, {@code 4h+1},
     *        {@code 4h+2} the peephole vectors wFF, wOO, wGG
     * @param inputWeights {@code size(0)} is taken as the size of the layer below
     * @param epsilon error signal from the layer above; {@code size(0)} is the mini-batch size and,
     *        for rank 3, {@code size(2)} the number of time steps
     * @param truncatedBPTT when {@code true}, only the last {@code tbpttBackwardLength} loop
     *        iterations are executed
     * @param tbpttBackwardLength number of steps to go back when {@code truncatedBPTT} is set
     * @param fwdPass stored forward-pass values, as produced with {@code forBackprop == true}
     * @param forwards direction the forward pass was run in
     * @param inputWeightKey key of the input-weight gradient in {@code gradientViews}
     * @param recurrentWeightKey key of the recurrent-weight gradient in {@code gradientViews}
     * @param biasWeightKey key of the bias gradient in {@code gradientViews}
     * @param gradientViews arrays the gradients are accumulated into
     * @param maskArray if non-null, column {@code time} is multiplied into the deltas and into the
     *        outgoing epsilon slice of each time step
     * @param hasPeepholeConnections whether peephole weights take part in the computation
     * @param helper optional alternative implementation, may be {@code null}
     * @return the gradient paired with the epsilon array of shape
     *         {@code [miniBatchSize, inputWeights.size(0), timeSeriesLength]}
     */
    public static Pair<Gradient, INDArray> backpropGradientHelper(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT, int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey, String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews, INDArray maskArray, boolean hasPeepholeConnections, LSTMHelper helper) {
        Pair<Gradient, INDArray> ret;
        int hiddenLayerSize = recurrentWeights.size(0);
        int prevLayerSize = inputWeights.size(0);
        int miniBatchSize = epsilon.size(0);
        boolean is2dInput = epsilon.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : epsilon.size(2);
        INDArray wFFTranspose = null;
        INDArray wOOTranspose = null;
        INDArray wGGTranspose = null;
        if (hasPeepholeConnections) {
            // Peephole weight vectors: the three columns following the 4*h recurrent columns.
            wFFTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize))}).transpose();
            wOOTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 1))}).transpose();
            wGGTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 2))}).transpose();
        }
        INDArray wIFOG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)(4 * hiddenLayerSize))});
        // Epsilon for the layer below, filled one time-step slice at a time.
        INDArray epsilonNext = Nd4j.create((int[])new int[]{miniBatchSize, prevLayerSize, timeSeriesLength}, (char)'f');
        INDArray nablaCellStateNext = null;
        // Single [miniBatch, 4h] delta buffer reused across iterations. The four arrays below are
        // column-range slices of it: per-gate deltas are written through the slices and the
        // buffer is then used as a whole in the gemm calls.
        INDArray deltaifogNext = Nd4j.create((int[])new int[]{miniBatchSize, 4 * hiddenLayerSize}, (char)'f');
        INDArray deltaiNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
        INDArray deltafNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)hiddenLayerSize, (int)(2 * hiddenLayerSize))});
        INDArray deltaoNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))});
        INDArray deltagNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
        // Loop index at which back-propagation stops (inclusive); > 0 only for truncated BPTT.
        int endIdx = 0;
        if (truncatedBPTT) {
            endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
        }
        INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
        INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey);
        INDArray bGradientsOut = gradientViews.get(biasWeightKey);
        // Gradients are accumulated with += semantics below, so start from zero.
        iwGradientsOut.assign((Number)0);
        rwGradientsOut.assign((Number)0);
        bGradientsOut.assign((Number)0);
        INDArray rwGradientsIFOG = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)(4 * hiddenLayerSize))});
        INDArray rwGradientsFF = null;
        INDArray rwGradientsOO = null;
        INDArray rwGradientsGG = null;
        if (hasPeepholeConnections) {
            rwGradientsFF = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize))});
            rwGradientsOO = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 1))});
            rwGradientsGG = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 2))});
        }
        // Give the optional helper the first chance; a null result falls through to the loop below.
        if (helper != null && (ret = helper.backpropGradient(conf, gateActivationFn, input, recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, forwards, inputWeightKey, recurrentWeightKey, biasWeightKey, gradientViews, maskArray, hasPeepholeConnections)) != null) {
            return ret;
        }
        boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
        IActivation afn = ((org.deeplearning4j.nn.conf.layers.BaseLayer)conf.getLayer()).getActivationFn();
        // A per-iteration "LOOP_LSTM" workspace is used only when some workspace is currently
        // active and it is not the one with id "LOOP_EXTERNAL"; otherwise workspace stays null.
        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace() != null && !Nd4j.getMemoryManager().getCurrentWorkspace().getId().equals("LOOP_EXTERNAL") ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceConfigurationLSTM, "LOOP_LSTM") : null;
        INDArray timeStepMaskColumn = null;
        for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; --iTimeIndex) {
            INDArray deltaog;
            INDArray nablaCellState;
            if (workspace != null) {
                workspace.notifyScopeEntered();
            }
            // 'time' indexes the stored arrays; 'inext' is the offset to the step processed just
            // before this one in the loop (i.e. the later step of the forward pass).
            int time = iTimeIndex;
            int inext = 1;
            if (!forwards) {
                time = timeSeriesLength - iTimeIndex - 1;
                inext = -1;
            }
            // Cell-state error: seeded from the peephole contributions of the previously processed
            // step's forget and input-modulation deltas, or zero on the first iteration / without
            // peepholes.
            if (iTimeIndex != timeSeriesLength - 1 && hasPeepholeConnections) {
                nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
                l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), nablaCellState);
            } else {
                nablaCellState = Nd4j.create((int[])new int[]{miniBatchSize, hiddenLayerSize}, (char)'f');
            }
            // At loop index 0 the previous state is the initial state recorded by the forward pass
            // (which may be null); otherwise it is the neighbouring time step's stored value.
            INDArray prevMemCellState = iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[time - inext];
            INDArray prevHiddenUnitActivation = iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[time - inext];
            INDArray currMemCellState = fwdPass.memCellState[time];
            INDArray epsilonSlice = is2dInput ? epsilon : epsilon.tensorAlongDimension(time, new int[]{1, 0});
            // Copy of this step's epsilon; the recurrent error is added to the copy below.
            INDArray nablaOut = Shape.toOffsetZeroCopy((INDArray)epsilonSlice, (char)'f');
            if (iTimeIndex != timeSeriesLength - 1) {
                // nablaOut += deltaifogNext * wIFOG^T (deltaifogNext still holds the previous iteration's deltas).
                Nd4j.gemm((INDArray)deltaifogNext, (INDArray)wIFOG, (INDArray)nablaOut, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
            }
            INDArray sigmahOfS = fwdPass.memCellActivations[time];
            INDArray ao = fwdPass.oa[time];
            // Output gate delta, written into the output-gate slice of deltaifogNext.
            INDArray deltao = deltaoNext;
            Nd4j.getExecutioner().exec((Op)new OldMulOp(nablaOut, sigmahOfS, deltao));
            if (sigmoidGates) {
                // Sigmoid derivative computed from the stored activation (on a copy of ao).
                INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn((TransformOp)new TimesOneMinus(ao.dup('f')));
                deltao.muli(sigmaoPrimeOfZo);
            } else {
                deltao.assign((INDArray)gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst());
            }
            // NOTE: ao.muli(...) multiplies fwdPass.oa[time] in place; deltao was computed above,
            // before this overwrite.
            INDArray temp = (INDArray)afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst();
            l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);
            if (hasPeepholeConnections) {
                // Output-gate peephole contribution to the cell-state error.
                INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
                l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState);
            }
            if (iTimeIndex != timeSeriesLength - 1) {
                // Cell-state error flowing back through the forget gate of the previously processed
                // step. NOTE: muli overwrites fwdPass.fa[time + inext] in place.
                INDArray nextForgetGateAs = fwdPass.fa[time + inext];
                int length = nablaCellState.length();
                l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState);
            }
            // Kept for the next iteration; leveraged out of the loop workspace when one is in use.
            nablaCellStateNext = workspace == null ? nablaCellState : nablaCellState.leverage();
            INDArray af = fwdPass.fa[time];
            INDArray deltaf = null;
            // Forget gate delta is only computed when there is a previous cell state to multiply with.
            if (iTimeIndex > 0 || prevMemCellState != null) {
                deltaf = deltafNext;
                if (sigmoidGates) {
                    Nd4j.getExecutioner().exec((Op)new TimesOneMinus(af, deltaf));
                    deltaf.muli(nablaCellState);
                    deltaf.muli(prevMemCellState);
                } else {
                    INDArray temp2 = nablaCellState.mul(prevMemCellState);
                    deltaf.assign((INDArray)gateActivationFn.backprop(fwdPass.fz[time].dup('f'), temp2).getFirst());
                }
            }
            INDArray ag = fwdPass.ga[time];
            INDArray ai = fwdPass.ia[time];
            // Input modulation gate delta, written into its slice of deltaifogNext.
            INDArray deltag = deltagNext;
            if (sigmoidGates) {
                Nd4j.getExecutioner().exec((Op)new TimesOneMinus(ag, deltag));
                deltag.muli(ai);
                deltag.muli(nablaCellState);
            } else {
                INDArray temp2 = Nd4j.getExecutioner().execAndReturn((TransformOp)new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized((int[])ai.shape(), (char)'f')));
                deltag.assign((INDArray)gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
            }
            INDArray zi = fwdPass.iz[time];
            // Input block delta: layer activation backprop on the stored pre-activation.
            INDArray deltai = deltaiNext;
            temp = Nd4j.getExecutioner().execAndReturn((TransformOp)new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized((int[])deltai.shape(), (char)'f')));
            deltai.assign((INDArray)afn.backprop(zi, temp).getFirst());
            if (maskArray != null) {
                // Mask all four deltas of this time step at once through the shared buffer.
                timeStepMaskColumn = maskArray.getColumn(time);
                deltaifogNext.muliColumnVector(timeStepMaskColumn);
            }
            INDArray prevLayerActivationSlice = Shape.toMmulCompatible((INDArray)(is2dInput ? input : input.tensorAlongDimension(time, new int[]{1, 0})));
            // Input weight gradients. When there is no previous activation (loop index 0 with a
            // null initial state) only the input block [0,h) and the output/input-modulation
            // blocks [2h,4h) are accumulated; the forget-gate block is skipped.
            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
                Nd4j.gemm((INDArray)prevLayerActivationSlice, (INDArray)deltaifogNext, (INDArray)iwGradientsOut, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
            } else {
                INDArray iwGradients_i = iwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
                Nd4j.gemm((INDArray)prevLayerActivationSlice, (INDArray)deltai, (INDArray)iwGradients_i, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
                INDArray iwGradients_og = iwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                deltaog = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                Nd4j.gemm((INDArray)prevLayerActivationSlice, (INDArray)deltaog, (INDArray)iwGradients_og, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
            }
            // Recurrent weight gradients (and forget / input-modulation peephole gradients) need a
            // previous activation / cell state, so they are skipped when there is none.
            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
                Nd4j.gemm((INDArray)prevHiddenUnitActivation, (INDArray)deltaifogNext, (INDArray)rwGradientsIFOG, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
                if (hasPeepholeConnections) {
                    INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(new int[]{0});
                    l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF);
                    INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(new int[]{0});
                    l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG);
                }
            }
            if (hasPeepholeConnections) {
                // Output-gate peephole gradient uses the current cell state, so it is always accumulated.
                INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(new int[]{0});
                l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO);
            }
            // Bias gradients: column sums of the deltas, with the same forget-gate exclusion as above.
            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
                l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(new int[]{0}), bGradientsOut);
            } else {
                l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(new int[]{0}), bGradientsOut);
                INDArray ogBiasToAdd = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))}).sum(new int[]{0});
                INDArray ogBiasGrad = bGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);
            }
            // Epsilon for the layer below at this time step: deltas times input weights transposed.
            INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, new int[]{1, 0});
            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
                Nd4j.gemm((INDArray)deltaifogNext, (INDArray)inputWeights, (INDArray)epsilonNextSlice, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
            } else {
                INDArray wi = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
                Nd4j.gemm((INDArray)deltai, (INDArray)wi, (INDArray)epsilonNextSlice, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
                deltaog = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                INDArray wog = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                Nd4j.gemm((INDArray)deltaog, (INDArray)wog, (INDArray)epsilonNextSlice, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
            }
            if (maskArray != null) {
                epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
            }
            // Close the per-iteration workspace scope opened at the top of the loop body.
            // NOTE(review): this is not in a finally block, so an exception inside the loop body
            // would skip the close() — confirm whether that matters to callers.
            if (workspace == null) continue;
            workspace.close();
        }
        DefaultGradient retGradient = new DefaultGradient();
        retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
        retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
        retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
        return new Pair((Object)retGradient, (Object)epsilonNext);
    }

    /**
     * Builds the memory report for an LSTM layer configuration by delegating to the
     * three-argument overload.
     *
     * @param lstmLayer layer configuration; the Graves-specific estimate is used when it is a
     *        {@link GravesLSTM}
     * @param inputType input type of the layer
     * @return the memory report produced by the three-argument overload
     */
    public static LayerMemoryReport getMemoryReport(AbstractLSTM lstmLayer, InputType inputType) {
        return LSTMHelpers.getMemoryReport(lstmLayer instanceof GravesLSTM, lstmLayer, inputType);
    }

    /**
     * Builds the memory report for a bidirectional Graves LSTM: the Graves estimate from the
     * three-argument overload is computed once and every memory figure in it is doubled.
     *
     * <p>The report is tagged with the class of {@code lstmLayer}, matching what the
     * three-argument overload does. (Previously the class of the intermediate report object
     * itself was passed as the layer class.)
     *
     * <p>NOTE(review): whether {@code lstmLayer.initializer().numParams(...)} already counts both
     * directions is not visible from this method — confirm that doubling the parameter and
     * updater-state sizes is intended.
     *
     * @param lstmLayer bidirectional layer configuration
     * @param inputType input type of the layer
     * @return a report whose parameter, updater, working and cache memory are twice the
     *         single-direction Graves estimate
     */
    public static LayerMemoryReport getMemoryReport(GravesBidirectionalLSTM lstmLayer, InputType inputType) {
        LayerMemoryReport r = LSTMHelpers.getMemoryReport(true, lstmLayer, inputType);

        // Double every per-cache-mode figure of the single-direction report.
        Map<CacheMode, Long> fixedTrain = new HashMap<>();
        Map<CacheMode, Long> varTrain = new HashMap<>();
        Map<CacheMode, Long> cacheFixed = new HashMap<>();
        Map<CacheMode, Long> cacheVar = new HashMap<>();
        for (CacheMode cm : CacheMode.values()) {
            fixedTrain.put(cm, 2L * r.getWorkingMemoryFixedTrain().get(cm));
            varTrain.put(cm, 2L * r.getWorkingMemoryVariableTrain().get(cm));
            cacheFixed.put(cm, 2L * r.getCacheModeMemFixed().get(cm));
            cacheVar.put(cm, 2L * r.getCacheModeMemVariablePerEx().get(cm));
        }

        // Use the layer's class, not r.getClass() (which is the report's own class).
        return new LayerMemoryReport.Builder(r.getLayerName(), lstmLayer.getClass(), r.getInputType(), r.getOutputType())
                .standardMemory(2L * r.getParameterSize(), 2L * r.getUpdaterStateSize())
                .workingMemory(2L * r.getWorkingMemoryFixedInference(), 2L * r.getWorkingMemoryVariableInference(), fixedTrain, varTrain)
                .cacheMemory(cacheFixed, cacheVar)
                .build();
    }

    /**
     * Estimates the memory use of an LSTM layer for a recurrent input type.
     *
     * <p>All per-example figures are proportional to {@code timeSeriesLength * nOut}: a factor of
     * 4 for inference working memory, 6 for the forward-pass values kept for training, and 9
     * (Graves) or 6 (otherwise) for back-propagation working space. With
     * {@link CacheMode#NONE} the forward-pass training values count as working memory; with any
     * other cache mode they are reported as cache memory instead.
     *
     * <p>Sizes are computed with {@code long} arithmetic so that large
     * {@code timeSeriesLength * nOut} products and large updater state sizes do not overflow or
     * get truncated to {@code int}.
     *
     * @param isGraves whether to use the larger (factor 9) back-propagation working-space estimate
     * @param lstmLayer layer configuration supplying nOut, the parameter initializer and the updater
     * @param inputType input type; must be an {@link InputType.InputTypeRecurrent}
     * @return the memory report for the layer (layer name is left {@code null})
     * @throws ClassCastException if {@code inputType} is not recurrent
     */
    public static LayerMemoryReport getMemoryReport(boolean isGraves, FeedForwardLayer lstmLayer, InputType inputType) {
        InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent)inputType;
        long tsLength = itr.getTimeSeriesLength();
        long nOut = lstmLayer.getNOut();
        InputType outputType = lstmLayer.getOutputType(-1, inputType);

        long numParams = lstmLayer.initializer().numParams(lstmLayer);
        long updaterSize = lstmLayer.getIUpdater().stateSize(numParams);

        // Per-example sizes; every product is evaluated in long.
        long workingMemInferencePerEx = tsLength * 4L * nOut;
        long fwdPassPerTimeStepTrainCache = tsLength * 6L * nOut;
        long backpropWorkingSpace = (isGraves ? 9L : 6L) * tsLength * nOut;

        Map<CacheMode, Long> trainVariable = new HashMap<>();
        Map<CacheMode, Long> cacheVariable = new HashMap<>();
        for (CacheMode cm : CacheMode.values()) {
            long trainWorking;
            long cacheMem;
            if (cm == CacheMode.NONE) {
                // Nothing cached: forward-pass training values live in working memory.
                trainWorking = workingMemInferencePerEx + fwdPassPerTimeStepTrainCache + backpropWorkingSpace;
                cacheMem = 0L;
            } else {
                trainWorking = workingMemInferencePerEx + backpropWorkingSpace;
                cacheMem = fwdPassPerTimeStepTrainCache;
            }
            trainVariable.put(cm, trainWorking);
            cacheVariable.put(cm, cacheMem);
        }

        return new LayerMemoryReport.Builder(null, lstmLayer.getClass(), inputType, outputType)
                .standardMemory(numParams, updaterSize)
                .workingMemory(0L, workingMemInferencePerEx, MemoryReport.CACHE_MODE_ALL_ZEROS, trainVariable)
                .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cacheVariable)
                .build();
    }
}

