/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/**
 * Runtime implementation of a bidirectional recurrent layer. It wraps a forward and a backward
 * {@link RecurrentLayer} and combines their outputs according to the mode of the
 * {@link Bidirectional} layer configuration (ADD, MUL, AVERAGE or CONCAT).
 *
 * <p>The forward layer receives the input (and mask) as given. The backward layer receives them
 * reversed along the time axis via {@link TimeSeriesUtils}. The backward layer's output is reversed
 * back before the two outputs are combined, and the same reversal is applied to its epsilons
 * during backprop.
 *
 * <p>Parameters and gradients of the wrapped layers are exposed under the sub-layer keys prefixed
 * with {@code "f"} (forward) or {@code "b"} (backward). The flattened parameter and gradient views
 * are split into two halves of equal length: the first half belongs to the forward layer, the
 * second half to the backward layer.
 *
 * <p>Single-step inference ({@code rnnTimeStep}) and truncated BPTT are not supported.
 */
public class BidirectionalLayer
implements RecurrentLayer {
    /** Key prefix for parameters/gradients belonging to the forward sub-layer. */
    private static final String FORWARD_PREFIX = "f";
    /** Key prefix for parameters/gradients belonging to the backward sub-layer. */
    private static final String BACKWARD_PREFIX = "b";

    private NeuralNetConfiguration conf;
    private final RecurrentLayer fwd;
    private final RecurrentLayer bwd;
    private final Bidirectional layerConf;
    private INDArray paramsView;
    private INDArray gradientView;
    // NOTE(review): not read or written anywhere in this class; kept to preserve the class shape.
    private transient Map<String, INDArray> gradientViews;
    private INDArray input;
    // Time-aligned activations of each direction; only stored in MUL mode, where backprop needs them.
    private INDArray outFwd;
    private INDArray outBwd;

    /**
     * Creates a bidirectional wrapper around two recurrent layers.
     *
     * @param conf configuration whose layer must be a {@link Bidirectional}
     * @param fwd  layer run on the input in its original time order
     * @param bwd  layer run on the time-reversed input
     * @throws NullPointerException if any argument is null
     * @throws ClassCastException   if {@code conf.getLayer()} is not a {@link Bidirectional}
     */
    public BidirectionalLayer(@NonNull NeuralNetConfiguration conf, @NonNull RecurrentLayer fwd, @NonNull RecurrentLayer bwd) {
        if (conf == null) {
            throw new NullPointerException("conf");
        }
        if (fwd == null) {
            throw new NullPointerException("fwd");
        }
        if (bwd == null) {
            throw new NullPointerException("bwd");
        }
        this.conf = conf;
        this.fwd = fwd;
        this.bwd = bwd;
        this.layerConf = (Bidirectional) conf.getLayer();
    }

    // ---------------------------------------------------------------------------------------------
    // Unsupported RNN state / truncated-BPTT operations
    // ---------------------------------------------------------------------------------------------

    @Override
    public INDArray rnnTimeStep(INDArray input) {
        throw new UnsupportedOperationException("Cannot RnnTimeStep bidirectional layers");
    }

    @Override
    public Map<String, INDArray> rnnGetPreviousState() {
        throw new UnsupportedOperationException("Not supported: cannot RnnTimeStep bidirectional layers therefore no previous state is supported");
    }

    @Override
    public void rnnSetPreviousState(Map<String, INDArray> stateMap) {
        throw new UnsupportedOperationException("Not supported: cannot RnnTimeStep bidirectional layers therefore no previous state is supported");
    }

    /** No-op: this layer never stores RNN time-step state (rnnTimeStep is unsupported). */
    @Override
    public void rnnClearPreviousState() {
    }

    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    @Override
    public Map<String, INDArray> rnnGetTBPTTState() {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    @Override
    public void rnnSetTBPTTState(Map<String, INDArray> state) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    // ---------------------------------------------------------------------------------------------
    // Delegated configuration / regularization
    // ---------------------------------------------------------------------------------------------

    @Override
    public void setCacheMode(CacheMode mode) {
        this.fwd.setCacheMode(mode);
        this.bwd.setCacheMode(mode);
    }

    /** Returns the sum of the L2 terms of both sub-layers. */
    @Override
    public double calcL2(boolean backpropOnlyParams) {
        return this.fwd.calcL2(backpropOnlyParams) + this.bwd.calcL2(backpropOnlyParams);
    }

    /** Returns the sum of the L1 terms of both sub-layers. */
    @Override
    public double calcL1(boolean backpropOnlyParams) {
        return this.fwd.calcL1(backpropOnlyParams) + this.bwd.calcL1(backpropOnlyParams);
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    // ---------------------------------------------------------------------------------------------
    // Forward / backward pass
    // ---------------------------------------------------------------------------------------------

    /**
     * Backpropagates {@code epsilon} through both directions.
     *
     * <p>The epsilon is first split or scaled according to the combination mode. The backward
     * layer's share is reversed in time before it is passed to the backward layer, and the
     * backward layer's output epsilon is reversed back. The two input epsilons are then summed.
     *
     * @param epsilon gradient w.r.t. this layer's combined output. In CONCAT mode it is indexed as
     *                a rank-3 array whose dimension 1 holds [forward | backward] halves; presumably
     *                [minibatch, size, timeSteps] — TODO confirm
     * @return the merged gradient (keys prefixed "f"/"b") and the epsilon w.r.t. this layer's input
     * @throws IllegalStateException if the mode is MUL and no activations were stored by a prior
     *                               {@code activate} call
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray eFwd;
        INDArray eBwd;
        switch (this.layerConf.getMode()) {
            case ADD: {
                // d(a + b)/da = d(a + b)/db = 1: both directions receive the full epsilon.
                eFwd = epsilon;
                eBwd = epsilon;
                break;
            }
            case MUL: {
                // d(a * b)/da = b and d(a * b)/db = a, using the activations stored during activate().
                if (this.outFwd == null || this.outBwd == null) {
                    throw new IllegalStateException("Cannot backprop in MUL mode: no stored activations (activate must be called before backpropGradient)");
                }
                eFwd = epsilon.dup(epsilon.ordering()).muli(this.outBwd);
                eBwd = epsilon.dup(epsilon.ordering()).muli(this.outFwd);
                break;
            }
            case AVERAGE: {
                // d((a + b) / 2)/da = d((a + b) / 2)/db = 0.5
                eFwd = epsilon.dup(epsilon.ordering()).muli(0.5);
                eBwd = eFwd;
                break;
            }
            case CONCAT: {
                int n = epsilon.size(1) / 2;
                eFwd = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(0, n), NDArrayIndex.all());
                eBwd = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(n, 2 * n), NDArrayIndex.all());
                break;
            }
            default: {
                throw new RuntimeException("Unknown mode: " + this.layerConf.getMode());
            }
        }
        // The backward layer operated on reversed time, so its epsilon must be reversed too.
        eBwd = TimeSeriesUtils.reverseTimeSeries(eBwd);

        Pair<Gradient, INDArray> g1 = this.fwd.backpropGradient(eFwd);
        Pair<Gradient, INDArray> g2 = this.bwd.backpropGradient(eBwd);

        Gradient g = new DefaultGradient(this.gradientView);
        for (Map.Entry<String, INDArray> e : g1.getFirst().gradientForVariable().entrySet()) {
            g.gradientForVariable().put(FORWARD_PREFIX + e.getKey(), e.getValue());
        }
        for (Map.Entry<String, INDArray> e : g2.getFirst().gradientForVariable().entrySet()) {
            g.gradientForVariable().put(BACKWARD_PREFIX + e.getKey(), e.getValue());
        }

        // Realign the backward layer's input epsilon with the original time order before summing.
        INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2.getRight());
        INDArray epsOut = g1.getRight().addi(g2Reversed);
        return new Pair<>(g, epsOut);
    }

    /** Equivalent to {@link #activate(INDArray)}; this layer has no separate pre-output. */
    @Override
    public INDArray preOutput(INDArray x) {
        return this.activate(x);
    }

    /** Equivalent to {@link #activate(INDArray, Layer.TrainingMode)}. */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        return this.activate(x, training);
    }

    @Override
    public INDArray activate(Layer.TrainingMode training) {
        return this.activate(training == Layer.TrainingMode.TRAIN);
    }

    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        this.setInput(input);
        return this.activate(training);
    }

    /** Equivalent to {@link #activate(INDArray, boolean)}. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return this.activate(x, training);
    }

    /**
     * Runs both directions on the previously set input and combines their outputs.
     *
     * <p>ADD and AVERAGE modify the forward layer's output array in place. MUL additionally stores
     * detached copies of both outputs for use in {@link #backpropGradient(INDArray)}. CONCAT
     * concatenates along dimension 1.
     */
    @Override
    public INDArray activate(boolean training) {
        INDArray out1 = this.fwd.activate(training);
        INDArray out2 = this.bwd.activate(training);
        // The backward layer ran on reversed input; reverse its output to realign it in time.
        out2 = TimeSeriesUtils.reverseTimeSeries(out2);
        switch (this.layerConf.getMode()) {
            case ADD: {
                return out1.addi(out2);
            }
            case MUL: {
                this.outFwd = out1.detach();
                this.outBwd = out2.detach();
                return out1.mul(out2);
            }
            case AVERAGE: {
                return out1.addi(out2).muli(0.5);
            }
            case CONCAT: {
                return Nd4j.concat(1, out1, out2);
            }
            default: {
                throw new RuntimeException("Unknown mode: " + this.layerConf.getMode());
            }
        }
    }

    @Override
    public INDArray activate(INDArray input, boolean training) {
        this.setInput(input);
        return this.activate(training);
    }

    /** Activates in inference (non-training) mode. */
    @Override
    public INDArray activate() {
        return this.activate(false);
    }

    @Override
    public INDArray activate(INDArray input) {
        this.setInput(input);
        return this.activate();
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Cannot transpose layer");
    }

    @Override
    public Layer clone() {
        throw new UnsupportedOperationException("Clone not supported");
    }

    // ---------------------------------------------------------------------------------------------
    // Listeners / training entry points
    // ---------------------------------------------------------------------------------------------

    /** Returns the forward layer's listeners (both sub-layers receive the same listeners via the setters). */
    @Override
    public Collection<IterationListener> getListeners() {
        return this.fwd.getListeners();
    }

    @Override
    public void setListeners(IterationListener ... listeners) {
        this.fwd.setListeners(listeners);
        this.bwd.setListeners(listeners);
    }

    @Override
    public void addListeners(IterationListener ... listener) {
        this.fwd.addListeners(listener);
        this.bwd.addListeners(listener);
    }

    @Override
    public void fit() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void update(Gradient gradient) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void update(INDArray gradient, String paramType) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Returns the sum of both sub-layers' scores. */
    @Override
    public double score() {
        return this.fwd.score() + this.bwd.score();
    }

    @Override
    public void computeGradientAndScore() {
        this.fwd.computeGradientAndScore();
        this.bwd.computeGradientAndScore();
    }

    @Override
    public void accumulateScore(double accum) {
        this.fwd.accumulateScore(accum);
        this.bwd.accumulateScore(accum);
    }

    // ---------------------------------------------------------------------------------------------
    // Parameters and gradient views
    // ---------------------------------------------------------------------------------------------

    @Override
    public INDArray params() {
        return this.paramsView;
    }

    @Override
    public int numParams() {
        return this.fwd.numParams() + this.bwd.numParams();
    }

    @Override
    public int numParams(boolean backwards) {
        return this.fwd.numParams(backwards) + this.bwd.numParams(backwards);
    }

    /** Copies {@code params} into the existing parameter view (which must already be set). */
    @Override
    public void setParams(INDArray params) {
        this.paramsView.assign(params);
    }

    /**
     * Sets the flattened parameter view. The first half is handed to the forward layer and the
     * second half to the backward layer.
     *
     * @param params row-vector parameter view; assumed to have even length (both directions are
     *               assumed to have the same number of parameters)
     */
    @Override
    public void setParamsViewArray(INDArray params) {
        this.paramsView = params;
        this.fwd.setParamsViewArray(half(params, false));
        this.bwd.setParamsViewArray(half(params, true));
    }

    @Override
    public INDArray getGradientsViewArray() {
        return this.gradientView;
    }

    /**
     * Sets the flattened gradient view. The first half is handed to the forward layer and the
     * second half to the backward layer.
     *
     * @throws IllegalArgumentException if the parameter view is set and the gradient length differs
     *                                  from {@link #numParams()}
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        if (this.paramsView != null) {
            int expected = this.numParams();
            if (gradients.length() != expected) {
                throw new IllegalArgumentException("Invalid input: expect gradients array of length " + expected + ", got array of length " + gradients.length());
            }
        }
        this.gradientView = gradients;
        this.fwd.setBackpropGradientsViewArray(half(gradients, false));
        this.bwd.setBackpropGradientsViewArray(half(gradients, true));
    }

    /**
     * Returns the first (forward) or second (backward) half of a row vector, selected with
     * {@code point(0)} plus an interval over the columns.
     */
    private static INDArray half(INDArray rowVector, boolean secondHalf) {
        int n = rowVector.length() / 2;
        int from = secondHalf ? n : 0;
        return rowVector.get(NDArrayIndex.point(0), NDArrayIndex.interval(from, from + n));
    }

    @Override
    public void fit(INDArray data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int batchSize() {
        return this.fwd.batchSize();
    }

    @Override
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    // NOTE(review): only replaces the stored configuration; the Bidirectional mode captured in the
    // constructor (layerConf) is not refreshed.
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    @Override
    public INDArray input() {
        return this.input;
    }

    @Override
    public void validateInput() {
    }

    /** Returns null: this wrapper has no optimizer of its own. */
    @Override
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /**
     * Returns a sub-layer parameter by prefixed key ({@code "f..."} or {@code "b..."}).
     *
     * @throws IllegalArgumentException if the key is null or lacks a valid direction prefix
     */
    @Override
    public INDArray getParam(String param) {
        boolean forward = isForwardKey(param);
        // Both prefixes are one character long; strip it to obtain the sub-layer's key.
        String sub = param.substring(1);
        return forward ? this.fwd.getParam(sub) : this.bwd.getParam(sub);
    }

    @Override
    public void initParams() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return this.paramTable(false);
    }

    /**
     * Returns the parameters of both sub-layers in a new insertion-ordered map. Forward keys are
     * prefixed {@code "f"} and come first; backward keys are prefixed {@code "b"}.
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        LinkedHashMap<String, INDArray> m = new LinkedHashMap<String, INDArray>();
        for (Map.Entry<String, INDArray> e : this.fwd.paramTable(backpropParamsOnly).entrySet()) {
            m.put(FORWARD_PREFIX + e.getKey(), e.getValue());
        }
        for (Map.Entry<String, INDArray> e : this.bwd.paramTable(backpropParamsOnly).entrySet()) {
            m.put(BACKWARD_PREFIX + e.getKey(), e.getValue());
        }
        return m;
    }

    /** Sets each entry via {@link #setParam(String, INDArray)}; keys must carry an "f"/"b" prefix. */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            this.setParam(e.getKey(), e.getValue());
        }
    }

    /**
     * Sets a sub-layer parameter by prefixed key ({@code "f..."} or {@code "b..."}).
     *
     * @throws IllegalArgumentException if the key is null or lacks a valid direction prefix
     */
    @Override
    public void setParam(String key, INDArray val) {
        boolean forward = isForwardKey(key);
        String sub = key.substring(1);
        if (forward) {
            this.fwd.setParam(sub, val);
        } else {
            this.bwd.setParam(sub, val);
        }
    }

    /**
     * Returns true if {@code key} addresses the forward layer, false for the backward layer.
     *
     * @throws IllegalArgumentException if the key is null or starts with neither prefix
     */
    private static boolean isForwardKey(String key) {
        if (key != null && key.startsWith(FORWARD_PREFIX)) {
            return true;
        }
        if (key != null && key.startsWith(BACKWARD_PREFIX)) {
            return false;
        }
        throw new IllegalArgumentException("Invalid parameter key \"" + key + "\": expected prefix \"" + FORWARD_PREFIX + "\" (forward layer) or \"" + BACKWARD_PREFIX + "\" (backward layer)");
    }

    /** Clears both sub-layers as well as the stored input and MUL-mode activations. */
    @Override
    public void clear() {
        this.fwd.clear();
        this.bwd.clear();
        this.input = null;
        this.outFwd = null;
        this.outBwd = null;
    }

    @Override
    public void applyConstraints(int iteration, int epoch) {
        this.fwd.applyConstraints(iteration, epoch);
        this.bwd.applyConstraints(iteration, epoch);
    }

    @Override
    public void init() {
    }

    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        this.fwd.setListeners(listeners);
        this.bwd.setListeners(listeners);
    }

    @Override
    public void setIndex(int index) {
        this.fwd.setIndex(index);
        this.bwd.setIndex(index);
    }

    @Override
    public int getIndex() {
        return this.fwd.getIndex();
    }

    @Override
    public int getIterationCount() {
        return this.fwd.getIterationCount();
    }

    @Override
    public int getEpochCount() {
        return this.fwd.getEpochCount();
    }

    @Override
    public void setIterationCount(int iterationCount) {
        this.fwd.setIterationCount(iterationCount);
        this.bwd.setIterationCount(iterationCount);
    }

    @Override
    public void setEpochCount(int epochCount) {
        this.fwd.setEpochCount(epochCount);
        this.bwd.setEpochCount(epochCount);
    }

    /**
     * Stores the input. The forward layer receives it as-is; the backward layer receives it reversed
     * in time. A null input is passed through as null to both layers.
     */
    @Override
    public void setInput(INDArray input) {
        this.input = input;
        this.fwd.setInput(input);
        this.bwd.setInput(input == null ? null : TimeSeriesUtils.reverseTimeSeries(input));
    }

    @Override
    public void migrateInput() {
        this.fwd.migrateInput();
        this.bwd.migrateInput();
    }

    @Override
    public void setInputMiniBatchSize(int size) {
        this.fwd.setInputMiniBatchSize(size);
        this.bwd.setInputMiniBatchSize(size);
    }

    @Override
    public int getInputMiniBatchSize() {
        return this.fwd.getInputMiniBatchSize();
    }

    /**
     * Sets the mask. The forward layer receives it as-is; the backward layer receives it reversed
     * in time. A null mask (no masking) is passed through as null to both layers.
     */
    @Override
    public void setMaskArray(INDArray maskArray) {
        this.fwd.setMaskArray(maskArray);
        this.bwd.setMaskArray(reverseMaskOrNull(maskArray));
    }

    /** Returns the forward layer's (non-reversed) mask. */
    @Override
    public INDArray getMaskArray() {
        return this.fwd.getMaskArray();
    }

    @Override
    public boolean isPretrainLayer() {
        return this.fwd.isPretrainLayer();
    }

    @Override
    public void clearNoiseWeightParams() {
        this.fwd.clearNoiseWeightParams();
        this.bwd.clearNoiseWeightParams();
    }

    /**
     * Feeds the mask forward through both layers (the backward layer gets the time-reversed mask).
     * Only the forward layer's result is returned.
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        Pair<INDArray, MaskState> ret = this.fwd.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
        this.bwd.feedForwardMaskArray(reverseMaskOrNull(maskArray), currentMaskState, minibatchSize);
        return ret;
    }

    /** Reverses a time-series mask, passing null (no mask) through unchanged. */
    private static INDArray reverseMaskOrNull(INDArray maskArray) {
        return maskArray == null ? null : TimeSeriesUtils.reverseTimeSeriesMask(maskArray);
    }
}

