/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/**
 * Simple ("vanilla") recurrent layer.
 *
 * <p>For each time step {@code t} the layer computes
 * {@code out_t = activationFn(in_t * W + out_(t-1) * RW + b)}, where the recurrent term is
 * skipped for the first step when no previous activation is available.
 *
 * <p>The code indexes the input as {@code [dim0, dim1, timeSteps]}: dimension 0 is used as the
 * number of rows of the output, and dimension 2 is iterated as the time axis. The output has
 * shape {@code [input.size(0), nOut, input.size(2)]} in 'f' order.
 */
public class SimpleRnn
extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn> {
    /** Key under which the last time step's activations are kept in the state maps. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";

    public SimpleRnn(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Runs a forward pass that continues from the activation stored in {@code stateMap}
     * (if any), then stores the activation of the last time step of this call for the next one.
     *
     * @param input input activations; its dimension 2 is treated as the time axis
     * @return activations for every time step of {@code input}
     */
    @Override
    public INDArray rnnTimeStep(INDArray input) {
        this.setInput(input);
        INDArray last = (INDArray)this.stateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = this.activateHelper(last, false, false).getFirst();
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            // Store a copy, not a view: a view would share its buffer with the array returned
            // to the caller, so later in-place changes to that array would alter the state.
            this.stateMap.put(STATE_KEY_PREV_ACTIVATION, SimpleRnn.copyOfLastTimeStep(out));
        }
        return out;
    }

    /**
     * Runs a forward pass that continues from the activation stored in {@code tBpttStateMap}
     * (if any).
     *
     * @param input             input activations; its dimension 2 is treated as the time axis
     * @param training          forwarded to dropout, parameter noise and the activation function
     * @param storeLastForTBPTT if true, the activation of the last time step is stored in
     *                          {@code tBpttStateMap} for the next call
     * @return activations for every time step of {@code input}
     */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        this.setInput(input);
        INDArray last = (INDArray)this.tBpttStateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = this.activateHelper(last, training, false).getFirst();
        if (storeLastForTBPTT) {
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                // Copy for the same reason as in rnnTimeStep: the state must not alias 'out'.
                this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, SimpleRnn.copyOfLastTimeStep(out));
            }
        }
        return out;
    }

    /**
     * Backpropagates through every time step; equivalent to
     * {@code tbpttBackpropGradient(epsilon, -1)}.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        return this.tbpttBackpropGradient(epsilon, -1);
    }

    /**
     * Computes parameter gradients and the gradient with respect to the input, walking the time
     * axis backwards.
     *
     * @param epsilon         gradient with respect to this layer's output; it is copied, so the
     *                        caller's array is not modified
     * @param tbpttBackLength if positive, only the last {@code tbpttBackLength} time steps are
     *                        processed; otherwise all time steps are processed
     * @return the gradient object (entries "W", "RW" and "b" backed by the gradient views) and
     *         the gradient with respect to the input, which has the input's shape. Time steps
     *         that are not processed because of truncation are zero in the returned array.
     */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength) {
        // Work on a private copy: the recurrent contribution is accumulated into it in place.
        epsilon = epsilon.dup('f');

        // Forward pass repeated to obtain activations and pre-activations (z) per time step.
        Pair<INDArray, INDArray> forward = this.activateHelper(null, true, true);
        INDArray activations = forward.getFirst();
        INDArray preActivations = forward.getSecond();

        INDArray w = this.getParamWithNoise("W", true);
        INDArray rw = this.getParamWithNoise("RW", true);
        INDArray wg = (INDArray)this.gradientViews.get("W");
        INDArray rwg = (INDArray)this.gradientViews.get("RW");
        INDArray bg = (INDArray)this.gradientViews.get("b");
        // The gemm calls below accumulate (beta = 1.0), so the gradient buffer must start at zero.
        this.gradientsFlattened.assign((Number)0);

        IActivation a = ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn)this.layerConf()).getActivationFn();
        int tsLength = this.input.size(2);
        int end = tbpttBackLength > 0 ? Math.max(0, tsLength - tbpttBackLength) : 0;

        INDArray epsOut = Nd4j.createUninitialized(this.input.shape(), 'f');
        if (end > 0) {
            // The loop below only writes steps [end, tsLength). Without this, the truncated
            // steps [0, end) would be returned as uninitialized memory.
            epsOut.assign((Number)0);
        }

        INDArray dldzNext = null;
        for (int i = tsLength - 1; i >= end; --i) {
            INDArray dldaCurrent = SimpleRnn.timeStep(epsilon, i);
            INDArray aCurrent = SimpleRnn.timeStep(activations, i);
            INDArray zCurrent = SimpleRnn.timeStep(preActivations, i);
            INDArray inCurrent = SimpleRnn.timeStep(this.input, i);
            INDArray epsOutCurrent = SimpleRnn.timeStep(epsOut, i);

            if (dldzNext != null) {
                // Add the gradient flowing back from step i+1 through the recurrent weights:
                // dL/da_i += dL/dz_(i+1) * RW^T
                Nd4j.gemm(dldzNext, rw, dldaCurrent, false, true, 1.0, 1.0);
            }

            // Copies are passed in; the original slices are left as they are.
            INDArray dldzCurrent = (INDArray)a.backprop(zCurrent.dup(), dldaCurrent.dup()).getFirst();

            INDArray maskCol = null;
            if (this.maskArray != null) {
                maskCol = this.maskArray.getColumn(i);
                dldzCurrent.muliColumnVector(maskCol);
            }

            // dL/dW += in_i^T * dL/dz_i
            Nd4j.gemm(inCurrent, dldzCurrent, wg, true, false, 1.0, 1.0);
            if (dldzNext != null) {
                // dL/dRW += a_i^T * dL/dz_(i+1)
                Nd4j.gemm(aCurrent, dldzNext, rwg, true, false, 1.0, 1.0);
            }
            // dL/db += sum over dimension 0 of dL/dz_i
            bg.addi(dldzCurrent.sum(0));
            // dL/din_i = dL/dz_i * W^T (beta = 0.0: overwrites the uninitialized slice)
            Nd4j.gemm(dldzCurrent, w, epsOutCurrent, false, true, 1.0, 0.0);

            dldzNext = dldzCurrent;
            if (maskCol != null) {
                epsOutCurrent.muliColumnVector(maskCol);
            }
        }

        this.weightNoiseParams.clear();

        DefaultGradient g = new DefaultGradient(this.gradientsFlattened);
        g.gradientForVariable().put("W", wg);
        g.gradientForVariable().put("RW", rwg);
        g.gradientForVariable().put("b", bg);
        return new Pair<Gradient, INDArray>(g, epsOut);
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * Delegates to {@link #activate(boolean)}; the returned values are therefore the
     * post-activation outputs.
     */
    @Override
    public INDArray preOutput(boolean training) {
        return this.activate(training);
    }

    /**
     * Forward pass over the current input, starting without a previous activation.
     *
     * @param training forwarded to dropout, parameter noise and the activation function
     * @return activations for every time step
     */
    @Override
    public INDArray activate(boolean training) {
        return this.activateHelper(null, training, false).getFirst();
    }

    /**
     * Shared forward-pass implementation.
     *
     * @param prevStepOut activation to use as the previous step for the first time step, or
     *                    {@code null} to skip the recurrent term on the first step
     * @param training    forwarded to dropout, parameter noise and the activation function
     * @param forBackprop if true, the pre-activation values are also collected and returned
     * @return pair of (activations, pre-activations); the second element is {@code null} when
     *         {@code forBackprop} is false
     */
    private Pair<INDArray, INDArray> activateHelper(INDArray prevStepOut, boolean training, boolean forBackprop) {
        this.applyDropOutIfNecessary(training);
        int m = this.input.size(0);
        int tsLength = this.input.size(2);
        int nOut = ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn)this.layerConf()).getNOut();
        INDArray w = this.getParamWithNoise("W", training);
        INDArray rw = this.getParamWithNoise("RW", training);
        INDArray b = this.getParamWithNoise("b", training);

        INDArray out = Nd4j.createUninitialized(new int[]{m, nOut, tsLength}, 'f');
        INDArray outZ = forBackprop ? Nd4j.createUninitialized(out.shape()) : null;

        // NOTE(review): as written, the input is also copied whenever
        // strideDescendingCAscendingF(...) returns true - confirm this is the intended condition.
        if (this.input.ordering() != 'f' || Shape.strideDescendingCAscendingF(this.input)) {
            this.input = this.input.dup('f');
        }

        // Initialise every time step with the bias; the gemm calls below accumulate onto it.
        Nd4j.getExecutioner().exec((Op)new BroadcastCopyOp(out, b, out, new int[]{1}));

        IActivation a = ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn)this.layerConf()).getActivationFn();
        for (int i = 0; i < tsLength; ++i) {
            INDArray currOut = SimpleRnn.timeStep(out, i);
            INDArray currIn = SimpleRnn.timeStep(this.input, i);

            // currOut (= b) += in_i * W
            Nd4j.gemm(currIn, w, currOut, false, false, 1.0, 1.0);
            // After the first iteration prevStepOut is always set, so this only skips the
            // recurrent term for step 0 when no stored state was supplied.
            if (prevStepOut != null) {
                // currOut += out_(i-1) * RW
                Nd4j.gemm(prevStepOut, rw, currOut, false, false, 1.0, 1.0);
            }
            if (forBackprop) {
                // Keep the pre-activation values before the activation overwrites currOut.
                SimpleRnn.timeStep(outZ, i).assign(currOut);
            }
            a.getActivation(currOut, training);
            prevStepOut = currOut;
        }

        if (this.maskArray != null) {
            // Mask is broadcast along dimensions 0 and 2 of the 3d arrays.
            Nd4j.getExecutioner().exec((Op)new BroadcastMulOp(out, this.maskArray, out, new int[]{0, 2}));
            if (forBackprop) {
                Nd4j.getExecutioner().exec((Op)new BroadcastMulOp(outZ, this.maskArray, outZ, new int[]{0, 2}));
            }
        }
        return new Pair<INDArray, INDArray>(out, outZ);
    }

    /** Returns the view {@code arr[:, :, t]} of a 3d array. */
    private static INDArray timeStep(INDArray arr, int t) {
        return arr.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(t));
    }

    /** Returns a copy (not a view) of the last slice along dimension 2 of a 3d array. */
    private static INDArray copyOfLastTimeStep(INDArray arr) {
        return SimpleRnn.timeStep(arr, arr.size(2) - 1).dup();
    }
}

