/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.recurrent;

import java.util.Collection;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.params.BidirectionalParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

@JsonIgnoreProperties(value={"initializer"})
public class Bidirectional
extends Layer {
    /** Configuration used for the forward direction (the layer supplied by the caller). */
    private Layer fwd;
    /** Configuration used for the backward direction; set to a clone of the wrapped layer by the constructors. */
    private Layer bwd;
    /** How forward and backward outputs are merged; see {@link Mode}. */
    private Mode mode;
    /** Lazily created by {@link #initializer()}; transient and excluded from JSON via the class annotation. */
    private transient BidirectionalParamInitializer initializer;

    /**
     * Creates the configuration from a {@link Builder}.
     * <p>
     * Previously this only copied the base-layer settings and silently dropped the builder's RNN
     * layer and mode, producing an unusable instance. Now both are applied. When no mode was given,
     * the mode defaults to {@link Mode#CONCAT}, matching {@link #Bidirectional(Layer)}.
     *
     * @param builder builder holding the wrapped layer and mode
     * @throws IllegalStateException    if no RNN layer was set on the builder
     * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
     */
    private Bidirectional(Builder builder) {
        super(builder);
        if (builder.layer == null) {
            throw new IllegalStateException("Cannot build Bidirectional layer: no RNN layer set (use Builder.rnnLayer(Layer) or Builder(Mode, Layer))");
        }
        validateWrappedLayer(builder.layer);
        this.fwd = builder.layer;
        this.bwd = builder.layer.clone();
        this.mode = builder.mode == null ? Mode.CONCAT : builder.mode;
    }

    /**
     * Wraps {@code layer} bidirectionally, concatenating the two directions' outputs.
     *
     * @param layer recurrent layer configuration to wrap; must not be null
     * @throws NullPointerException     if {@code layer} is null
     * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
     */
    public Bidirectional(@NonNull Layer layer) {
        // The delegated constructor already rejects a null layer.
        this(Mode.CONCAT, layer);
    }

    /**
     * Wraps {@code layer} bidirectionally with the given merge mode. The forward direction uses
     * {@code layer} itself and the backward direction uses a clone of it.
     *
     * @param mode  how the two directions' outputs are merged; must not be null
     * @param layer recurrent layer configuration to wrap; must not be null
     * @throws NullPointerException     if {@code mode} or {@code layer} is null
     * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
     */
    public Bidirectional(@NonNull Mode mode, @NonNull Layer layer) {
        if (mode == null) {
            throw new NullPointerException("mode");
        }
        if (layer == null) {
            throw new NullPointerException("layer");
        }
        validateWrappedLayer(layer);
        this.fwd = layer;
        this.bwd = layer.clone();
        this.mode = mode;
    }

    /** No-arg constructor, left for deserialization/setter-based construction. */
    public Bidirectional() {
    }

    /**
     * Rejects layers that cannot be wrapped: only {@link BaseRecurrentLayer} and
     * {@link BaseWrapperLayer} subclasses are accepted.
     *
     * @param layer non-null layer to check
     * @throws IllegalArgumentException if the layer type is not accepted
     */
    private static void validateWrappedLayer(Layer layer) {
        if (!(layer instanceof BaseRecurrentLayer) && !(layer instanceof BaseWrapperLayer)) {
            throw new IllegalArgumentException("Cannot wrap a non-recurrent layer: config must extend BaseRecurrentLayer. Got class: " + layer.getClass());
        }
    }

    /**
     * Instantiates the runtime {@link BidirectionalLayer}.
     * <p>
     * The network configuration is cloned once per direction. The parameter view is split into two
     * equal halves: the first goes to {@code fwd` and the second to {@code bwd}. Each wrapped layer
     * is instantiated on its half and must produce a {@link RecurrentLayer}.
     *
     * @param conf              network configuration for this layer
     * @param trainingListeners listeners passed to both wrapped layers
     * @param layerIndex        index of this layer; reused for both wrapped layers
     * @param layerParamsView   this layer's parameter view; presumably a row vector (row 0 is
     *                          indexed below) holding both directions' parameters back to back
     * @param initializeParams  forwarded to both wrapped layers and to the param initializer
     * @return the bidirectional runtime layer
     */
    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) {
        NeuralNetConfiguration fwdConf = conf.clone();
        NeuralNetConfiguration bwdConf = conf.clone();
        fwdConf.setLayer(this.fwd);
        bwdConf.setLayer(this.bwd);

        // Both directions are clones of the same layer, so each owns exactly half the parameters.
        int half = layerParamsView.length() / 2;
        INDArray fwdParams = layerParamsView.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(0, half)});
        INDArray bwdParams = layerParamsView.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(half, 2 * half)});

        RecurrentLayer fwdLayer = (RecurrentLayer) this.fwd.instantiate(fwdConf, trainingListeners, layerIndex, fwdParams, initializeParams);
        RecurrentLayer bwdLayer = (RecurrentLayer) this.bwd.instantiate(bwdConf, trainingListeners, layerIndex, bwdParams, initializeParams);

        BidirectionalLayer ret = new BidirectionalLayer(conf, fwdLayer, bwdLayer);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    /** Returns the parameter initializer, creating it on first use. */
    @Override
    public ParamInitializer initializer() {
        if (this.initializer == null) {
            this.initializer = new BidirectionalParamInitializer(this);
        }
        return this.initializer;
    }

    /**
     * Computes the output type from the forward layer's output type.
     * <p>
     * {@link Mode#CONCAT} doubles the size. The other modes return the forward layer's type unchanged.
     * When {@code fwd} is a {@link LastTimeStep}, its output type is treated as feed-forward.
     * Otherwise it is treated as recurrent.
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        InputType outOrig = this.fwd.getOutputType(layerIndex, inputType);
        if (this.fwd instanceof LastTimeStep) {
            InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) outOrig;
            return this.mode == Mode.CONCAT ? InputType.feedForward(2 * ff.getSize()) : ff;
        }
        InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) outOrig;
        return this.mode == Mode.CONCAT ? InputType.recurrent(2 * r.getSize()) : r;
    }

    /** Sets the input size on both the forward and backward layer configurations. */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        this.fwd.setNIn(inputType, override);
        this.bwd.setNIn(inputType, override);
    }

    /** Delegates to the forward layer configuration. */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return this.fwd.getPreProcessorForInputType(inputType);
    }

    /**
     * Looks up L1 for a parameter. The first character of {@code paramName} is dropped
     * (presumably the direction prefix added by {@link BidirectionalParamInitializer}), and the
     * forward layer's setting is used for both directions.
     */
    @Override
    public double getL1ByParam(String paramName) {
        return this.fwd.getL1ByParam(paramName.substring(1));
    }

    /** Looks up L2 for a parameter; the name handling is the same as {@link #getL1ByParam(String)}. */
    @Override
    public double getL2ByParam(String paramName) {
        return this.fwd.getL2ByParam(paramName.substring(1));
    }

    /** Checks whether a parameter is a pretrain parameter; the name handling is the same as {@link #getL1ByParam(String)}. */
    @Override
    public boolean isPretrainParam(String paramName) {
        return this.fwd.isPretrainParam(paramName.substring(1));
    }

    /** Looks up the updater for a parameter; the name handling is the same as {@link #getL1ByParam(String)}. */
    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        return this.fwd.getUpdaterByParam(paramName.substring(1));
    }

    /** Sets the same name on this layer and on both wrapped layer configurations. */
    @Override
    public void setLayerName(String layerName) {
        this.layerName = layerName;
        this.fwd.setLayerName(layerName);
        this.bwd.setLayerName(layerName);
    }

    /** Returns the forward layer's memory report scaled by 2, one copy per direction. */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        LayerMemoryReport lmr = this.fwd.getMemoryReport(inputType);
        lmr.scale(2);
        return lmr;
    }

    public Layer getFwd() {
        return this.fwd;
    }

    public Layer getBwd() {
        return this.bwd;
    }

    public Mode getMode() {
        return this.mode;
    }

    /** Returns the cached initializer, which may be null; use {@link #initializer()} to create it on demand. */
    public BidirectionalParamInitializer getInitializer() {
        return this.initializer;
    }

    public void setFwd(Layer fwd) {
        this.fwd = fwd;
    }

    public void setBwd(Layer bwd) {
        this.bwd = bwd;
    }

    public void setMode(Mode mode) {
        this.mode = mode;
    }

    public void setInitializer(BidirectionalParamInitializer initializer) {
        this.initializer = initializer;
    }

    @Override
    public String toString() {
        return "Bidirectional(fwd=" + this.getFwd() + ", bwd=" + this.getBwd() + ", mode=" + this.getMode() + ", initializer=" + this.getInitializer() + ")";
    }

    /** Equality covers the superclass state plus {@code fwd}, {@code bwd} and {@code mode}; the initializer is ignored. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Bidirectional)) {
            return false;
        }
        Bidirectional other = (Bidirectional) o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        Layer thisFwd = this.getFwd();
        Layer otherFwd = other.getFwd();
        if (thisFwd == null ? otherFwd != null : !thisFwd.equals(otherFwd)) {
            return false;
        }
        Layer thisBwd = this.getBwd();
        Layer otherBwd = other.getBwd();
        if (thisBwd == null ? otherBwd != null : !thisBwd.equals(otherBwd)) {
            return false;
        }
        Mode thisMode = this.getMode();
        Mode otherMode = other.getMode();
        return thisMode == null ? otherMode == null : thisMode.equals(otherMode);
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof Bidirectional;
    }

    /** Hash code consistent with {@link #equals(Object)}; a null field contributes 43. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = super.hashCode();
        Layer fwdVal = this.getFwd();
        result = result * PRIME + (fwdVal == null ? 43 : fwdVal.hashCode());
        Layer bwdVal = this.getBwd();
        result = result * PRIME + (bwdVal == null ? 43 : bwdVal.hashCode());
        Mode modeVal = this.getMode();
        result = result * PRIME + (modeVal == null ? 43 : modeVal.hashCode());
        return result;
    }

    /**
     * Builder for {@link Bidirectional}. An RNN layer must be set before {@link #build()}. If no
     * mode is set, {@link Mode#CONCAT} is used.
     */
    public static class Builder
    extends Layer.Builder<Builder> {
        private Mode mode;
        private Layer layer;

        /** Creates an empty builder; set the layer with {@link #rnnLayer(Layer)}. */
        public Builder() {
        }

        /**
         * Creates a builder with both values preset. The values are not validated until
         * {@link #build()}.
         */
        public Builder(Mode mode, Layer layer) {
            this.mode = mode;
            this.layer = layer;
        }

        /** Sets how the two directions' outputs are merged. */
        public Builder mode(Mode mode) {
            this.mode = mode;
            return this;
        }

        /**
         * Sets the recurrent layer configuration to wrap.
         *
         * @throws NullPointerException     if {@code layer} is null
         * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
         */
        public Builder rnnLayer(Layer layer) {
            if (layer == null) {
                throw new NullPointerException("layer");
            }
            Bidirectional.validateWrappedLayer(layer);
            this.layer = layer;
            return this;
        }

        @Override
        public Bidirectional build() {
            return new Bidirectional(this);
        }
    }

    /**
     * How the forward and backward outputs are merged. Only {@link #CONCAT} changes the reported
     * output size: it doubles it (see {@link Bidirectional#getOutputType}). The merge itself is done
     * by the runtime layer. Presumably ADD, MUL and AVERAGE are element-wise sum, product and mean;
     * confirm against {@link BidirectionalLayer}.
     */
    public static enum Mode {
        ADD,
        MUL,
        AVERAGE,
        CONCAT;

    }
}

