/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Gated Recurrent Unit (GRU) recurrent layer.
 *
 * <p>The layer is currently disabled: both constructors throw
 * {@link UnsupportedOperationException} because the backprop implementation is
 * known to be incorrect. The remaining code is kept for reference.
 *
 * <p>Parameter layout used throughout this class (taken from the slicing below):
 * the input weights {@code "W"}, recurrent weights {@code "RW"} and biases
 * {@code "b"} each hold three column groups of width {@code layerSize}, in the
 * order reset gate (r), update gate (u), candidate (c).
 */
public class GRU
extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.GRU> {
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";

    private static final String DISABLED_MESSAGE = "GRU layer disabled: Backprop implementation is incorrect in this version. Consider using GravesLSTM instead";
    private static final String NOT_IMPLEMENTED_MESSAGE = "Not yet implemented";

    private static final String INPUT_WEIGHT_KEY = "W";
    private static final String RECURRENT_WEIGHT_KEY = "RW";
    private static final String BIAS_KEY = "b";

    /** Index of each gate's column group inside W, RW, b and the stored z/a arrays. */
    private static final int RESET = 0;
    private static final int UPDATE = 1;
    private static final int CANDIDATE = 2;

    /**
     * @throws UnsupportedOperationException always; the layer is disabled
     */
    public GRU(NeuralNetConfiguration conf) {
        super(conf);
        throw new UnsupportedOperationException(DISABLED_MESSAGE);
    }

    /**
     * @throws UnsupportedOperationException always; the layer is disabled
     */
    public GRU(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
        throw new UnsupportedOperationException(DISABLED_MESSAGE);
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED_MESSAGE);
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray activation) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED_MESSAGE);
    }

    /**
     * Backpropagation through time over the full sequence.
     *
     * <p>Re-runs the forward pass (training mode, zero initial state) on the
     * currently set input, then walks the time steps in reverse accumulating
     * gradients for {@code "W"}, {@code "RW"} and {@code "b"}.
     *
     * <p>NOTE: the constructors of this class state that this implementation is
     * incorrect; the arithmetic below is preserved as-is and has not been fixed.
     *
     * @param epsilon error signal from the layer above; rank &lt; 3 is treated as
     *                a single time step, otherwise dimension 2 is the time axis
     * @return the parameter gradients and the epsilon to pass to the layer below,
     *         allocated as {@code [miniBatchSize, prevLayerSize, timeSeriesLength]}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray[] activations = this.activateHelper(true, null);
        INDArray outputActivations = activations[0];
        INDArray rucZs = activations[1];
        INDArray rucAs = activations[2];

        INDArray inputWeights = this.getParam(INPUT_WEIGHT_KEY);
        INDArray recurrentWeights = this.getParam(RECURRENT_WEIGHT_KEY);
        int layerSize = recurrentWeights.size(0);
        int prevLayerSize = inputWeights.size(0);
        int miniBatchSize = epsilon.size(0);
        boolean is2dInput = epsilon.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : epsilon.size(2);

        INDArray wr = gate(inputWeights, RESET, layerSize);
        INDArray wu = gate(inputWeights, UPDATE, layerSize);
        INDArray wc = gate(inputWeights, CANDIDATE, layerSize);
        INDArray wR = gate(recurrentWeights, RESET, layerSize);
        INDArray wU = gate(recurrentWeights, UPDATE, layerSize);
        INDArray wC = gate(recurrentWeights, CANDIDATE, layerSize);
        INDArray wRdiag = Nd4j.diag(wR).transpose();
        INDArray wCdiag = Nd4j.diag(wC).transpose();

        INDArray biasGradients = Nd4j.zeros(new int[]{1, 3 * layerSize});
        INDArray inputWeightGradients = Nd4j.zeros(new int[]{prevLayerSize, 3 * layerSize});
        INDArray recurrentWeightGradients = Nd4j.zeros(new int[]{layerSize, 3 * layerSize});
        INDArray epsilonNext = Nd4j.zeros(new int[]{miniBatchSize, prevLayerSize, timeSeriesLength});

        // Delta w.r.t. the layer output at step t+1; zero beyond the last step.
        INDArray deltaOutNext = Nd4j.zeros(miniBatchSize, layerSize);

        for (int t = timeSeriesLength - 1; t >= 0; --t) {
            INDArray prevOut = t == 0 ? Nd4j.zeros(miniBatchSize, layerSize) : timeSlice(outputActivations, t - 1);
            // NOTE(review): for 2d input the stored z/a arrays are still 3d (see activateHelper),
            // yet they are used unsliced here - looks suspicious, confirm before re-enabling the layer.
            INDArray aSlice = is2dInput ? rucAs : timeSlice(rucAs, t);
            INDArray zSlice = is2dInput ? rucZs : timeSlice(rucZs, t);

            INDArray zr = gate(zSlice, RESET, layerSize);
            INDArray sigmaPrimeZr = sigmoidDerivative(zr);

            INDArray epsilonSlice = is2dInput ? epsilon : timeSlice(epsilon, t);
            INDArray deltaOut = epsilonSlice.dup();

            if (t < timeSeriesLength - 1) {
                // This branch is only reachable when timeSeriesLength > 1, i.e. for 3d input,
                // so the time slices can be taken unconditionally.
                INDArray aSliceNext = timeSlice(rucAs, t + 1);
                INDArray zSliceNext = timeSlice(rucZs, t + 1);
                INDArray aOut = timeSlice(outputActivations, t);

                INDArray arNext = gate(aSliceNext, RESET, layerSize);
                INDArray auNext = gate(aSliceNext, UPDATE, layerSize);
                INDArray acNext = gate(aSliceNext, CANDIDATE, layerSize);
                INDArray zrNext = gate(zSliceNext, RESET, layerSize);
                INDArray zuNext = gate(zSliceNext, UPDATE, layerSize);
                INDArray zcNext = gate(zSliceNext, CANDIDATE, layerSize);
                INDArray sigmaPrimeZrNext = sigmoidDerivative(zrNext);
                INDArray sigmaPrimeZuNext = sigmoidDerivative(zuNext);
                INDArray sigmaPrimeZcNext = this.candidateDerivative(zcNext);

                // Contributions flowing back from step t+1 into the output at step t.
                deltaOut.addi(auNext.mul(deltaOutNext));
                deltaOut.addi(aOut.sub(acNext).muli(sigmaPrimeZuNext).muli(wU.mmul(deltaOutNext.transpose()).transpose()));
                deltaOut.addi(auNext.rsub(1.0).muli(sigmaPrimeZcNext).muli(arNext.add(aOut.mul(sigmaPrimeZrNext).muliRowVector(wRdiag))).muli(wC.mmul(deltaOutNext.transpose()).transpose()));
            }

            INDArray zu = gate(zSlice, UPDATE, layerSize);
            INDArray sigmaPrimeZu = sigmoidDerivative(zu);
            INDArray ac = gate(aSlice, CANDIDATE, layerSize);
            INDArray deltaU = deltaOut.mul(sigmaPrimeZu).muli(prevOut.sub(ac));

            INDArray zc = gate(zSlice, CANDIDATE, layerSize);
            INDArray sigmaPrimeZc = this.candidateDerivative(zc);
            INDArray au = gate(aSlice, UPDATE, layerSize);
            INDArray deltaC = deltaOut.mul(sigmaPrimeZc).muli(au.rsub(1.0));

            INDArray deltaR = deltaC.mulRowVector(wCdiag).muli(prevOut).muli(sigmaPrimeZr);

            INDArray prevLayerActivationSlice = is2dInput ? this.input : timeSlice(this.input, t);
            gate(inputWeightGradients, RESET, layerSize).addi(deltaR.transpose().mmul(prevLayerActivationSlice).transpose());
            gate(inputWeightGradients, UPDATE, layerSize).addi(deltaU.transpose().mmul(prevLayerActivationSlice).transpose());
            gate(inputWeightGradients, CANDIDATE, layerSize).addi(deltaC.transpose().mmul(prevLayerActivationSlice).transpose());

            // At t == 0 the previous output is all zeros, so the recurrent terms are skipped.
            if (t > 0) {
                gate(recurrentWeightGradients, RESET, layerSize).addi(deltaR.transpose().mmul(prevOut).transpose());
                gate(recurrentWeightGradients, UPDATE, layerSize).addi(deltaU.transpose().mmul(prevOut).transpose());
                INDArray ar = gate(aSlice, RESET, layerSize);
                gate(recurrentWeightGradients, CANDIDATE, layerSize).addi(deltaC.transpose().mmul(prevOut.mul(ar)).transpose());
            }

            biasGate(biasGradients, RESET, layerSize).addi(deltaR.sum(0));
            biasGate(biasGradients, UPDATE, layerSize).addi(deltaU.sum(0));
            biasGate(biasGradients, CANDIDATE, layerSize).addi(deltaC.sum(0));

            INDArray epsilonNextSlice = wr.mmul(deltaR.transpose()).transpose().addi(wu.mmul(deltaU.transpose()).transpose()).addi(wc.mmul(deltaC.transpose()).transpose());
            timeSlice(epsilonNext, t).assign(epsilonNextSlice);

            deltaOutNext = deltaOut;
        }

        DefaultGradient g = new DefaultGradient();
        g.setGradientFor(INPUT_WEIGHT_KEY, inputWeightGradients);
        g.setGradientFor(RECURRENT_WEIGHT_KEY, recurrentWeightGradients);
        g.setGradientFor(BIAS_KEY, biasGradients);
        return new Pair<Gradient, INDArray>(g, epsilonNext);
    }

    /**
     * Delegates to {@link #activate(INDArray, boolean)} with {@code training = true}.
     * NOTE(review): this returns activations rather than pre-activation values - confirm intended.
     */
    @Override
    public INDArray preOutput(INDArray x) {
        return this.activate(x, true);
    }

    /** Delegates to {@link #activate(INDArray, boolean)}. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return this.activate(x, training);
    }

    /** Sets the input and returns the output activations of a forward pass from zero state. */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        this.setInput(input, training);
        return this.activateHelper(training, null)[0];
    }

    /** Sets the input and returns the output activations, running in training mode. */
    @Override
    public INDArray activate(INDArray input) {
        this.setInput(input);
        return this.activateHelper(true, null)[0];
    }

    /** Forward pass over the already-set input, starting from zero state. */
    @Override
    public INDArray activate(boolean training) {
        return this.activateHelper(training, null)[0];
    }

    /** Forward pass over the already-set input in non-training mode, starting from zero state. */
    @Override
    public INDArray activate() {
        return this.activateHelper(false, null)[0];
    }

    /**
     * Forward pass over every time step of the current input.
     *
     * @param training              when true and DropConnect is enabled with a positive
     *                              drop-out value, the input weights are replaced by the
     *                              result of {@link Dropout#applyDropConnect}
     * @param prevOutputActivations output of the step preceding the first one, or
     *                              {@code null} to start from zeros
     * @return {@code {outputActivations, rucZs, rucAs}}: the layer outputs
     *         {@code [miniBatch, layerSize, T]}, and the gate pre-activations and
     *         activations {@code [miniBatch, 3 * layerSize, T]} in r/u/c column order
     */
    private INDArray[] activateHelper(boolean training, INDArray prevOutputActivations) {
        INDArray inputWeights = this.getParam(INPUT_WEIGHT_KEY);
        INDArray recurrentWeights = this.getParam(RECURRENT_WEIGHT_KEY);
        INDArray biases = this.getParam(BIAS_KEY);

        boolean is2dInput = this.input.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : this.input.size(2);
        int layerSize = recurrentWeights.size(0);
        int miniBatchSize = this.input.size(0);

        // DropConnect must be applied BEFORE the gate views are taken; previously the
        // masked weights were assigned after slicing and therefore never used.
        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0) {
            inputWeights = Dropout.applyDropConnect(this, INPUT_WEIGHT_KEY);
        }

        INDArray wr = gate(inputWeights, RESET, layerSize);
        INDArray wu = gate(inputWeights, UPDATE, layerSize);
        INDArray wc = gate(inputWeights, CANDIDATE, layerSize);
        INDArray wR = gate(recurrentWeights, RESET, layerSize);
        INDArray wU = gate(recurrentWeights, UPDATE, layerSize);
        INDArray wC = gate(recurrentWeights, CANDIDATE, layerSize);
        INDArray br = biasGate(biases, RESET, layerSize);
        INDArray bu = biasGate(biases, UPDATE, layerSize);
        INDArray bc = biasGate(biases, CANDIDATE, layerSize);

        INDArray outputActivations = Nd4j.zeros(new int[]{miniBatchSize, layerSize, timeSeriesLength});
        INDArray rucZs = Nd4j.zeros(new int[]{miniBatchSize, 3 * layerSize, timeSeriesLength});
        INDArray rucAs = Nd4j.zeros(new int[]{miniBatchSize, 3 * layerSize, timeSeriesLength});

        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros(miniBatchSize, layerSize);
        }

        for (int t = 0; t < timeSeriesLength; ++t) {
            INDArray prevLayerInputSlice = is2dInput ? this.input : timeSlice(this.input, t);
            if (t > 0) {
                prevOutputActivations = timeSlice(outputActivations, t - 1);
            }

            INDArray zs = Nd4j.zeros(miniBatchSize, 3 * layerSize);
            INDArray as = Nd4j.zeros(miniBatchSize, 3 * layerSize);

            // Reset gate.
            INDArray zr = prevLayerInputSlice.mmul(wr).addi(prevOutputActivations.mmul(wR)).addiRowVector(br);
            INDArray ar = sigmoid(zr);
            gate(zs, RESET, layerSize).assign(zr);
            gate(as, RESET, layerSize).assign(ar);

            // Update gate.
            INDArray zu = prevLayerInputSlice.mmul(wu).addi(prevOutputActivations.mmul(wU)).addiRowVector(bu);
            INDArray au = sigmoid(zu);
            gate(zs, UPDATE, layerSize).assign(zu);
            gate(as, UPDATE, layerSize).assign(au);

            // Candidate activation; the previous output is scaled by the reset gate first.
            INDArray zc = prevLayerInputSlice.mmul(wc).addi(prevOutputActivations.mul(ar).mmul(wC)).addiRowVector(bc);
            INDArray ac = this.candidateActivation(zc);
            gate(zs, CANDIDATE, layerSize).assign(zc);
            gate(as, CANDIDATE, layerSize).assign(ac);

            // out = u * prevOut + (1 - u) * candidate
            INDArray aOut = au.mul(prevOutputActivations).addi(au.rsub(1).mul(ac));

            timeSlice(rucZs, t).assign(zs);
            timeSlice(rucAs, t).assign(as);
            timeSlice(outputActivations, t).assign(aOut);
        }

        return new INDArray[]{outputActivations, rucZs, rucAs};
    }

    /** Returns the result of {@link #activate()}. */
    @Override
    public INDArray activationMean() {
        return this.activate();
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED_MESSAGE);
    }

    /**
     * L2 penalty over the recurrent and input weights (biases excluded):
     * {@code 0.5 * l2 * (sum(RW^2) + sum(W^2))}, or 0 when regularization is
     * disabled or the L2 coefficient is not positive.
     */
    @Override
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getL2() <= 0.0) {
            return 0.0;
        }
        double recurrentSquares = Transforms.pow(this.getParam(RECURRENT_WEIGHT_KEY), 2).sum(Integer.MAX_VALUE).getDouble(0);
        double inputSquares = Transforms.pow(this.getParam(INPUT_WEIGHT_KEY), 2).sum(Integer.MAX_VALUE).getDouble(0);
        double l2 = recurrentSquares + inputSquares;
        return 0.5 * this.conf.getL2() * l2;
    }

    /**
     * L1 penalty over the recurrent and input weights (biases excluded):
     * {@code l1 * (sum(|RW|) + sum(|W|))}, or 0 when regularization is disabled
     * or the L1 coefficient is not positive.
     */
    @Override
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getL1() <= 0.0) {
            return 0.0;
        }
        double recurrentAbs = Transforms.abs(this.getParam(RECURRENT_WEIGHT_KEY)).sum(Integer.MAX_VALUE).getDouble(0);
        double inputAbs = Transforms.abs(this.getParam(INPUT_WEIGHT_KEY)).sum(Integer.MAX_VALUE).getDouble(0);
        double l1 = recurrentAbs + inputAbs;
        return this.conf.getL1() * l1;
    }

    /**
     * Forward pass (non-training) that continues from the activation stored in
     * {@code stateMap}, then stores a copy of the last time step's output back
     * into {@code stateMap} for the next call.
     */
    @Override
    public INDArray rnnTimeStep(INDArray input) {
        this.setInput(input);
        INDArray[] activations = this.activateHelper(false, this.stateMap.get(STATE_KEY_PREV_ACTIVATION));
        INDArray outAct = activations[0];
        int tLength = outAct.size(2);
        INDArray lastActSlice = timeSlice(outAct, tLength - 1);
        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, lastActSlice.dup());
        return outAct;
    }

    /**
     * Forward pass that continues from the activation stored in {@code stateMap}
     * without modifying it.
     *
     * @param training          forwarded to the forward pass (previously ignored and
     *                          hard-coded to {@code false})
     * @param storeLastForTBPTT when true, a copy of the last time step's output is
     *                          stored in {@code tBpttStateMap}
     */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        this.setInput(input);
        INDArray[] activations = this.activateHelper(training, this.stateMap.get(STATE_KEY_PREV_ACTIVATION));
        INDArray outAct = activations[0];
        if (storeLastForTBPTT) {
            int tLength = outAct.size(2);
            INDArray lastActSlice = timeSlice(outAct, tLength - 1);
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, lastActSlice.dup());
        }
        return outAct;
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackwardLength) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED_MESSAGE);
    }

    /** View of all rows and the columns belonging to gate {@code gateIndex}. */
    private static INDArray gate(INDArray matrix, int gateIndex, int layerSize) {
        return matrix.get(NDArrayIndex.all(), NDArrayIndex.interval(gateIndex * layerSize, (gateIndex + 1) * layerSize));
    }

    /** View of row 0 restricted to the columns belonging to gate {@code gateIndex}. */
    private static INDArray biasGate(INDArray rowVector, int gateIndex, int layerSize) {
        return rowVector.get(NDArrayIndex.point(0), NDArrayIndex.interval(gateIndex * layerSize, (gateIndex + 1) * layerSize));
    }

    /** Tensor number {@code t} along dimensions (1, 0), used here to select time step {@code t}. */
    private static INDArray timeSlice(INDArray array, int t) {
        return array.tensorAlongDimension(t, 1, 0);
    }

    /** Applies the "sigmoid" transform to a copy of {@code z}. */
    private static INDArray sigmoid(INDArray z) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", z.dup()));
    }

    /** Applies the derivative of the "sigmoid" transform to a copy of {@code z}. */
    private static INDArray sigmoidDerivative(INDArray z) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", z.dup()).derivative());
    }

    /** Applies the configured layer activation function to a copy of {@code z}. */
    private INDArray candidateActivation(INDArray z) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), z.dup()));
    }

    /** Applies the derivative of the configured layer activation function to a copy of {@code z}. */
    private INDArray candidateDerivative(INDArray z) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), z.dup()).derivative());
    }
}

