/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Static helper routines for the forward pass ({@link #activateHelper}) and the backward pass
 * ({@link #backpropGradientHelper}) of an LSTM layer.
 *
 * <p>Conventions visible in the code below, with {@code h = recurrentWeights.size(0)} (the hidden layer size):
 * <ul>
 *   <li>Gate pre-activations/activations are held in one {@code [miniBatchSize, 4h]} array whose column
 *       ranges are: {@code [0,h)} input activations (layer activation function), {@code [h,2h)} forget gate,
 *       {@code [2h,3h)} output gate, {@code [3h,4h)} input modulation gate (the three gates use "sigmoid").</li>
 *   <li>{@code recurrentWeights} columns {@code [0,4h)} are the recurrent weights for those four ranges;
 *       columns {@code 4h}, {@code 4h+1} and {@code 4h+2} are per-unit weights (wFF, wOO, wGG) that multiply
 *       the memory cell state element-wise into the forget, output and input modulation gates respectively.
 *       NOTE(review): these look like peephole-style connections - confirm against the layer documentation.</li>
 *   <li>Rank-3 arrays are indexed as {@code [miniBatch, size, time]}; an array of rank below 3 is treated as
 *       a single time step.</li>
 * </ul>
 *
 * <p>Both methods make heavy use of in-place operations on views ({@code muli}, {@code axpy}, {@code gemm}
 * into an existing array); statement order matters.
 */
public class LSTMHelpers {
    /**
     * Runs the LSTM forward pass over every time step of {@code input}.
     *
     * @param layer layer instance; only passed on to {@code Dropout.applyDropConnect}
     * @param conf configuration; supplies the drop-connect flag, the drop-out value and the activation function name
     * @param input input activations; must be non-null with non-zero length. {@code size(0)} is used as the
     *        mini-batch size and, for rank 3 or more, {@code size(2)} as the time series length
     * @param recurrentWeights recurrent weights including the three extra per-unit columns (see class comment)
     * @param originalInputWeights input weights; replaced by the result of {@code Dropout.applyDropConnect}
     *        when drop-connect is enabled, {@code training} is true and the layer drop-out value is above 0
     * @param biases bias values, added to the gate pre-activations as a row vector
     * @param training whether this is a training pass (only affects the drop-connect decision)
     * @param originalPrevOutputActivations output activations preceding the first processed step; {@code null}
     *        means zeros
     * @param originalPrevMemCellState memory cell state preceding the first processed step; {@code null} means
     *        zeros, otherwise a copy is taken so the argument itself is not modified here
     * @param forBackprop if true, the per-time-step arrays of the returned object are filled (and
     *        {@code fwdPassOutput} is not assigned by this method); if false, only the 3d
     *        {@code fwdPassOutput} array is written
     * @param forwards if false, time steps are processed from last to first
     * @param inputWeightKey parameter key passed to {@code Dropout.applyDropConnect}
     * @return the forward pass results; {@code lastAct} and {@code lastMemCell} always refer to the output and
     *         cell state of the final processed step
     * @throws IllegalArgumentException if {@code input} is null or has length 0
     */
    public static FwdPassReturn activateHelper(Layer layer, NeuralNetConfiguration conf, INDArray input, INDArray recurrentWeights, INDArray originalInputWeights, INDArray biases, boolean training, INDArray originalPrevOutputActivations, INDArray originalPrevMemCellState, boolean forBackprop, boolean forwards, String inputWeightKey) {
        if (input == null || input.length() == 0) {
            throw new IllegalArgumentException("Invalid input: not set or 0 length");
        }
        INDArray inputWeights = originalInputWeights;
        INDArray prevOutputActivations = originalPrevOutputActivations;
        boolean is2dInput = input.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : input.size(2);
        int hiddenLayerSize = recurrentWeights.size(0);
        int miniBatchSize = input.size(0);
        // Start from a zero cell state, or from a copy of the supplied one.
        INDArray prevMemCellState = originalPrevMemCellState == null ? Nd4j.create((int[])new int[]{miniBatchSize, hiddenLayerSize}, (char)'f') : originalPrevMemCellState.dup('f');
        // Recurrent weights for the four gate column ranges, copied into a separate 'f'-order array.
        INDArray recurrentWeightsIFOG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)(4 * hiddenLayerSize))}).dup('f');
        if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0.0) {
            inputWeights = Dropout.applyDropConnect(layer, inputWeightKey);
        }
        // Single-column slices 4h, 4h+1 and 4h+2, transposed so they can be applied as row vectors.
        INDArray wFFTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(4 * hiddenLayerSize), (int)(4 * hiddenLayerSize + 1))}).transpose();
        INDArray wOOTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(4 * hiddenLayerSize + 1), (int)(4 * hiddenLayerSize + 2))}).transpose();
        INDArray wGGTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(4 * hiddenLayerSize + 2), (int)(4 * hiddenLayerSize + 3))}).transpose();
        // NOTE(review): presumably converts the views once up front because they are reused on every
        // time step - confirm what Shape.toMmulCompatible does for these vectors.
        if (timeSeriesLength > 1 || forBackprop) {
            wFFTranspose = Shape.toMmulCompatible((INDArray)wFFTranspose);
            wOOTranspose = Shape.toMmulCompatible((INDArray)wOOTranspose);
            wGGTranspose = Shape.toMmulCompatible((INDArray)wGGTranspose);
        }
        INDArray outputActivations = null;
        FwdPassReturn toReturn = new FwdPassReturn();
        if (forBackprop) {
            // One entry per time step for everything the backward pass reads.
            toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
            toReturn.memCellState = new INDArray[timeSeriesLength];
            toReturn.memCellActivations = new INDArray[timeSeriesLength];
            toReturn.iz = new INDArray[timeSeriesLength];
            toReturn.ia = new INDArray[timeSeriesLength];
            toReturn.fa = new INDArray[timeSeriesLength];
            toReturn.oa = new INDArray[timeSeriesLength];
            toReturn.ga = new INDArray[timeSeriesLength];
        } else {
            // Only the combined [miniBatch, hidden, time] output is kept.
            toReturn.fwdPassOutput = outputActivations = Nd4j.create((int[])new int[]{miniBatchSize, hiddenLayerSize, timeSeriesLength}, (char)'f');
        }
        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros((int[])new int[]{miniBatchSize, hiddenLayerSize});
        }
        for (int iTimeIndex = 0; iTimeIndex < timeSeriesLength; ++iTimeIndex) {
            INDArray inputModMulInput;
            INDArray currentMemoryCellState;
            // 'time' is the index into the data; it runs backwards when forwards == false.
            int time = iTimeIndex;
            if (!forwards) {
                time = timeSeriesLength - iTimeIndex - 1;
            }
            INDArray miniBatchData = is2dInput ? input : input.tensorAlongDimension(time, new int[]{1, 0});
            miniBatchData = Shape.toMmulCompatible((INDArray)miniBatchData);
            // Pre-activations for all four column ranges: input * inputWeights ...
            INDArray ifogActivations = miniBatchData.mmul(inputWeights);
            // ... plus prevOutputActivations * recurrentWeightsIFOG, accumulated into the same array
            // (the two trailing 1.0 arguments are presumably gemm's alpha and beta) ...
            Nd4j.gemm((INDArray)prevOutputActivations, (INDArray)recurrentWeightsIFOG, (INDArray)ifogActivations, (boolean)false, (boolean)false, (double)1.0, (double)1.0);
            // ... plus the biases.
            ifogActivations.addiRowVector(biases);
            // Columns [0,h): input activations. The pre-activation values are copied to iz before the
            // activation function is applied; the transform result is not reassigned, so the code relies
            // on it updating the view itself.
            INDArray inputActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
            if (forBackprop) {
                toReturn.iz[time] = inputActivations.dup('f');
            }
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(conf.getLayer().getActivationFunction(), inputActivations));
            if (forBackprop) {
                toReturn.ia[time] = inputActivations;
            }
            // Columns [h,2h): forget gate. Adds prevMemCellState scaled per unit by wFF, then sigmoid.
            INDArray forgetGateActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)hiddenLayerSize, (int)(2 * hiddenLayerSize))});
            INDArray pmcellWFF = prevMemCellState.dup('f').muliRowVector(wFFTranspose);
            l1BLAS.axpy(pmcellWFF.length(), 1.0, pmcellWFF, forgetGateActivations);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", forgetGateActivations));
            if (forBackprop) {
                toReturn.fa[time] = forgetGateActivations;
            }
            // Columns [3h,4h): input modulation gate. Adds prevMemCellState scaled per unit by wGG, then sigmoid.
            INDArray inputModGateActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
            INDArray pmcellWGG = prevMemCellState.dup('f').muliRowVector(wGGTranspose);
            l1BLAS.axpy(pmcellWGG.length(), 1.0, pmcellWGG, inputModGateActivations);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", inputModGateActivations));
            if (forBackprop) {
                toReturn.ga[time] = inputModGateActivations;
            }
            // New cell state = forgetGate * prevMemCellState + inputModGate * inputActivations.
            // When the gate activations were stored for backprop, the products are formed on copies so the
            // stored arrays keep their values; otherwise the gate views are overwritten in place.
            if (forBackprop) {
                currentMemoryCellState = prevMemCellState.dup('f').muli(forgetGateActivations);
                inputModMulInput = inputModGateActivations.dup('f').muli(inputActivations);
            } else {
                currentMemoryCellState = forgetGateActivations.muli(prevMemCellState);
                inputModMulInput = inputModGateActivations.muli(inputActivations);
            }
            l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState);
            // Columns [2h,3h): output gate. Unlike the other two gates it uses the *current* cell state,
            // scaled per unit by wOO, then sigmoid.
            INDArray outputGateActivations = ifogActivations.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))});
            INDArray pmcellWOO = currentMemoryCellState.dup('f').muliRowVector(wOOTranspose);
            l1BLAS.axpy(pmcellWOO.length(), 1.0, pmcellWOO, outputGateActivations);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", outputGateActivations));
            if (forBackprop) {
                toReturn.oa[time] = outputGateActivations;
            }
            // Layer activation function applied to a copy of the cell state, then gated by the output gate.
            INDArray currMemoryCellActivation = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(conf.getLayer().getActivationFunction(), currentMemoryCellState.dup('f')));
            // For backprop the cell activation is stored separately below, so the product is formed on a copy.
            INDArray currHiddenUnitActivations = forBackprop ? currMemoryCellActivation.dup('f').muli(outputGateActivations) : currMemoryCellActivation.muli(outputGateActivations);
            if (forBackprop) {
                toReturn.fwdPassOutputAsArrays[time] = currHiddenUnitActivations;
                toReturn.memCellState[time] = currentMemoryCellState;
                toReturn.memCellActivations[time] = currMemoryCellActivation;
            } else {
                outputActivations.tensorAlongDimension(time, new int[]{1, 0}).assign(currHiddenUnitActivations);
            }
            // Carry the state over to the next processed time step.
            prevOutputActivations = currHiddenUnitActivations;
            prevMemCellState = currentMemoryCellState;
            toReturn.lastAct = currHiddenUnitActivations;
            toReturn.lastMemCell = currentMemoryCellState;
        }
        return toReturn;
    }

    /**
     * Runs the LSTM backward pass through time, accumulating parameter gradients into the arrays found in
     * {@code gradientViews} and computing the epsilon to hand to the layer below.
     *
     * <p>Side effects visible in this method: the three gradient arrays are first set to 0 and then added
     * to; {@code fwdPass.oa[time]} and {@code fwdPass.fa[time + inext]} are multiplied in place while the
     * loop runs, so the contents of {@code fwdPass} are not preserved.
     *
     * @param conf configuration; supplies the activation function name
     * @param input activations that were fed to the forward pass (rank below 3 is treated as one time step)
     * @param recurrentWeights recurrent weights including the three extra per-unit columns (see class comment)
     * @param inputWeights input weights; {@code size(0)} is used as the size of the layer below
     * @param epsilon error signal from the layer above; {@code size(0)} is the mini-batch size and, for
     *        rank 3 or more, {@code size(2)} the time series length
     * @param truncatedBPTT if true, only the last {@code tbpttBackwardLength} loop indices are processed
     * @param tbpttBackwardLength number of steps to go back when {@code truncatedBPTT} is true
     * @param fwdPass results of {@link #activateHelper} called with {@code forBackprop == true}
     * @param forwards must match the direction used for the forward pass; if false, the loop index is
     *        mirrored to obtain the data index
     * @param inputWeightKey key of the input weight gradient in {@code gradientViews} and in the result
     * @param recurrentWeightKey key of the recurrent weight gradient in {@code gradientViews} and in the result
     * @param biasWeightKey key of the bias gradient in {@code gradientViews} and in the result
     * @param gradientViews arrays to accumulate the gradients into; must contain the three keys above
     * @return pair of (gradient object holding the three arrays from {@code gradientViews}, epsilon for the
     *         layer below with shape {@code [miniBatchSize, prevLayerSize, timeSeriesLength]})
     */
    public static Pair<Gradient, INDArray> backpropGradientHelper(NeuralNetConfiguration conf, INDArray input, INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT, int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey, String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews) {
        int hiddenLayerSize = recurrentWeights.size(0);
        int prevLayerSize = inputWeights.size(0);
        int miniBatchSize = epsilon.size(0);
        boolean is2dInput = epsilon.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : epsilon.size(2);
        // Per-unit weight columns 4h, 4h+1, 4h+2 (transposed) and the [0,4h) recurrent weight block.
        INDArray wFFTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize))}).transpose();
        INDArray wOOTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 1))}).transpose();
        INDArray wGGTranspose = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 2))}).transpose();
        INDArray wIFOG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)(4 * hiddenLayerSize))});
        INDArray epsilonNext = Nd4j.create((int[])new int[]{miniBatchSize, prevLayerSize, timeSeriesLength}, (char)'f');
        INDArray nablaCellStateNext = null;
        // One buffer holds the deltas of all four column ranges. It is reused on every iteration: at the
        // top of an iteration it still contains the deltas of the previously processed step (hence
        // "Next"), and by the end it has been overwritten with the current step's deltas.
        INDArray deltaifogNext = Nd4j.create((int[])new int[]{miniBatchSize, 4 * hiddenLayerSize}, (char)'f');
        INDArray deltaiNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
        INDArray deltafNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)hiddenLayerSize, (int)(2 * hiddenLayerSize))});
        INDArray deltaoNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))});
        INDArray deltagNext = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
        // Lowest loop index to process; above 0 only for truncated BPTT.
        int endIdx = 0;
        if (truncatedBPTT) {
            endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
        }
        INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
        INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey);
        INDArray bGradientsOut = gradientViews.get(biasWeightKey);
        // Everything below accumulates, so the gradient arrays are cleared first.
        iwGradientsOut.assign((Number)0);
        rwGradientsOut.assign((Number)0);
        bGradientsOut.assign((Number)0);
        INDArray rwGradientsIFOG = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)(4 * hiddenLayerSize))});
        INDArray rwGradientsFF = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize))});
        INDArray rwGradientsOO = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 1))});
        INDArray rwGradientsGG = rwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((int)(4 * hiddenLayerSize + 2))});
        // iTimeIndex counts down in forward-pass processing order; iTimeIndex == timeSeriesLength - 1 is
        // the step the forward pass processed last, iTimeIndex == 0 the one it processed first.
        for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; --iTimeIndex) {
            INDArray nablaCellState;
            // 'time' is the data index; 'inext' is the data-index offset to the step the forward pass
            // processed next (+1 when forwards, -1 otherwise).
            int time = iTimeIndex;
            int inext = 1;
            if (!forwards) {
                time = timeSeriesLength - iTimeIndex - 1;
                inext = -1;
            }
            // Cell-state gradient, part 1: contribution through the next step's forget and input
            // modulation gates via wFF and wGG. Nothing to add for the first loop iteration.
            if (iTimeIndex != timeSeriesLength - 1) {
                nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
                l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), nablaCellState);
            } else {
                nablaCellState = Nd4j.create((int[])new int[]{miniBatchSize, hiddenLayerSize}, (char)'f');
            }
            // State from the step the forward pass processed just before this one; none exists at iTimeIndex 0.
            INDArray prevMemCellState = iTimeIndex == 0 ? null : fwdPass.memCellState[time - inext];
            INDArray prevHiddenUnitActivation = iTimeIndex == 0 ? null : fwdPass.fwdPassOutputAsArrays[time - inext];
            INDArray currMemCellState = fwdPass.memCellState[time];
            INDArray epsilonSlice = is2dInput ? epsilon : epsilon.tensorAlongDimension(time, new int[]{1, 0});
            // NOTE(review): presumably a copy, so the gemm below does not write into 'epsilon' - confirm
            // against Shape.toOffsetZeroCopy.
            INDArray nablaOut = Shape.toOffsetZeroCopy((INDArray)epsilonSlice, (char)'f');
            // Add the error flowing back through the recurrent connections: deltaifogNext * wIFOG^T.
            if (iTimeIndex != timeSeriesLength - 1) {
                Nd4j.gemm((INDArray)deltaifogNext, (INDArray)wIFOG, (INDArray)nablaOut, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
            }
            // Output gate delta, written into the [2h,3h) view of deltaifogNext:
            // nablaOut * sigmahOfS * ("timesoneminus" applied to ao, presumably ao * (1 - ao)).
            INDArray sigmahOfS = fwdPass.memCellActivations[time];
            INDArray ao = fwdPass.oa[time];
            INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("timesoneminus", ao.dup('f')));
            INDArray deltao = deltaoNext;
            Nd4j.getExecutioner().exec((Op)new MulOp(nablaOut, sigmahOfS, deltao));
            deltao.muli(sigmaoPrimeOfZo);
            // Cell-state gradient, part 2: path through the output (ao * nablaOut * activation derivative
            // at the cell state). ao.muli(...) overwrites fwdPass.oa[time]; ao is not read again afterwards.
            INDArray sigmahPrimeOfS = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(conf.getLayer().getActivationFunction(), currMemCellState.dup('f')).derivative());
            l1BLAS.axpy(nablaCellState.length(), 1.0, ao.muli(nablaOut).muli(sigmahPrimeOfS), nablaCellState);
            // Part 3: path through the output gate's per-unit weight wOO.
            INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
            l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState);
            // Part 4: the next step's cell-state gradient scaled by the next step's forget gate.
            // nextForgetGateAs.muli(...) overwrites fwdPass.fa[time + inext] in place.
            if (iTimeIndex != timeSeriesLength - 1) {
                INDArray nextForgetGateAs = fwdPass.fa[time + inext];
                int length = nablaCellState.length();
                l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState);
            }
            nablaCellStateNext = nablaCellState;
            // Forget gate delta, written into the [h,2h) view. It is skipped at iTimeIndex == 0 because
            // there is no previous cell state; in that case the [h,2h) columns of deltaifogNext still
            // hold the previous iteration's values, which is why the iTimeIndex == 0 branches below only
            // use columns [0,h) and [2h,4h).
            INDArray af = fwdPass.fa[time];
            INDArray deltaf = null;
            if (iTimeIndex > 0) {
                deltaf = deltafNext;
                Nd4j.getExecutioner().exec((Op)new TimesOneMinus(af, deltaf));
                deltaf.muli(nablaCellState);
                deltaf.muli(prevMemCellState);
            }
            // Input modulation gate delta, written into the [3h,4h) view.
            INDArray ag = fwdPass.ga[time];
            INDArray ai = fwdPass.ia[time];
            INDArray deltag = deltagNext;
            Nd4j.getExecutioner().exec((Op)new TimesOneMinus(ag, deltag));
            deltag.muli(ai);
            deltag.muli(nablaCellState);
            // Input activation delta, written into the [0,h) view: activation derivative evaluated at
            // the stored pre-activations zi, times ag, times the cell-state gradient.
            INDArray zi = fwdPass.iz[time];
            INDArray deltai = deltaiNext;
            Nd4j.getExecutioner().exec((Op)Nd4j.getOpFactory().createTransform(conf.getLayer().getActivationFunction(), zi, null, deltai).derivative());
            deltai.muli(ag);
            deltai.muli(nablaCellState);
            // Input weight gradients: input^T * deltas, accumulated over time steps.
            INDArray prevLayerActivationSlice = Shape.toMmulCompatible((INDArray)(is2dInput ? input : input.tensorAlongDimension(time, new int[]{1, 0})));
            if (iTimeIndex > 0) {
                Nd4j.gemm((INDArray)prevLayerActivationSlice, (INDArray)deltaifogNext, (INDArray)iwGradientsOut, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
            } else {
                // Same accumulation, but leaving out the forget gate columns [h,2h).
                INDArray iwGradients_i = iwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
                Nd4j.gemm((INDArray)prevLayerActivationSlice, (INDArray)deltai, (INDArray)iwGradients_i, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
                INDArray iwGradients_og = iwGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                INDArray deltaog = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                Nd4j.gemm((INDArray)prevLayerActivationSlice, (INDArray)deltaog, (INDArray)iwGradients_og, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
            }
            // Recurrent weight gradients. Everything that depends on the previous step's output or cell
            // state only exists for iTimeIndex > 0.
            if (iTimeIndex > 0) {
                Nd4j.gemm((INDArray)prevHiddenUnitActivation, (INDArray)deltaifogNext, (INDArray)rwGradientsIFOG, (boolean)true, (boolean)false, (double)1.0, (double)1.0);
                // Per-unit weights: element-wise product summed over dimension 0 (the mini-batch).
                INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(new int[]{0});
                l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF);
                INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(new int[]{0});
                l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG);
            }
            // wOO uses the current cell state, so it gets a gradient at every step.
            INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(new int[]{0});
            l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO);
            // Bias gradients: deltas summed over dimension 0 (the mini-batch).
            if (iTimeIndex > 0) {
                l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(new int[]{0}), bGradientsOut);
            } else {
                // Again without the forget gate columns: the first h entries, then entries [2h,4h).
                l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(new int[]{0}), bGradientsOut);
                INDArray ogBiasToAdd = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))}).sum(new int[]{0});
                INDArray ogBiasGrad = bGradientsOut.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
                l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);
            }
            // Epsilon for the layer below at this time step: deltas * inputWeights^T.
            INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, new int[]{1, 0});
            if (iTimeIndex > 0) {
                Nd4j.gemm((INDArray)deltaifogNext, (INDArray)inputWeights, (INDArray)epsilonNextSlice, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
                continue;
            }
            // iTimeIndex == 0: same product, split so the forget gate columns [h,2h) are left out.
            INDArray wi = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)hiddenLayerSize)});
            Nd4j.gemm((INDArray)deltai, (INDArray)wi, (INDArray)epsilonNextSlice, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
            INDArray deltaog = deltaifogNext.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
            INDArray wog = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(4 * hiddenLayerSize))});
            Nd4j.gemm((INDArray)deltaog, (INDArray)wog, (INDArray)epsilonNextSlice, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
        }
        // Package the (already accumulated) gradient arrays under their parameter keys.
        DefaultGradient retGradient = new DefaultGradient();
        retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
        retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
        retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
        return new Pair<Gradient, INDArray>(retGradient, epsilonNext);
    }
}

