/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.AbstractLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelpers;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;

public class GravesLSTM
extends AbstractLSTM {
    private double forgetGateBiasInit;
    private IActivation gateActivationFn = new ActivationSigmoid();

    private GravesLSTM(Builder builder) {
        super(builder);
        this.forgetGateBiasInit = builder.forgetGateBiasInit;
        this.gateActivationFn = builder.gateActivationFn;
        this.initializeConstraints(builder);
    }

    @Override
    protected void initializeConstraints(Layer.Builder<?> builder) {
        super.initializeConstraints(builder);
        if (((Builder)builder).recurrentConstraints != null) {
            if (this.constraints == null) {
                this.constraints = new ArrayList();
            }
            for (LayerConstraint c : ((Builder)builder).recurrentConstraints) {
                LayerConstraint c2 = c.clone();
                c2.setParams(Collections.singleton("RW"));
                this.constraints.add(c2);
            }
        }
    }

    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) {
        LayerValidation.assertNInNOutSet("GravesLSTM", this.getLayerName(), layerIndex, this.getNIn(), this.getNOut());
        org.deeplearning4j.nn.layers.recurrent.GravesLSTM ret = new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(conf);
        ret.setListeners(iterationListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        return GravesLSTMParamInitializer.getInstance();
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        return LSTMHelpers.getMemoryReport(this, inputType);
    }

    @Override
    public double getForgetGateBiasInit() {
        return this.forgetGateBiasInit;
    }

    @Override
    public IActivation getGateActivationFn() {
        return this.gateActivationFn;
    }

    @Override
    public void setForgetGateBiasInit(double forgetGateBiasInit) {
        this.forgetGateBiasInit = forgetGateBiasInit;
    }

    @Override
    public void setGateActivationFn(IActivation gateActivationFn) {
        this.gateActivationFn = gateActivationFn;
    }

    public GravesLSTM() {
    }

    @Override
    public String toString() {
        return "GravesLSTM(super=" + super.toString() + ", forgetGateBiasInit=" + this.getForgetGateBiasInit() + ", gateActivationFn=" + this.getGateActivationFn() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof GravesLSTM)) {
            return false;
        }
        GravesLSTM other = (GravesLSTM)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (Double.compare(this.getForgetGateBiasInit(), other.getForgetGateBiasInit()) != 0) {
            return false;
        }
        IActivation this$gateActivationFn = this.getGateActivationFn();
        IActivation other$gateActivationFn = other.getGateActivationFn();
        return !(this$gateActivationFn == null ? other$gateActivationFn != null : !this$gateActivationFn.equals(other$gateActivationFn));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof GravesLSTM;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        long $forgetGateBiasInit = Double.doubleToLongBits(this.getForgetGateBiasInit());
        result = result * 59 + (int)($forgetGateBiasInit >>> 32 ^ $forgetGateBiasInit);
        IActivation $gateActivationFn = this.getGateActivationFn();
        result = result * 59 + ($gateActivationFn == null ? 43 : $gateActivationFn.hashCode());
        return result;
    }

    public static class Builder
    extends AbstractLSTM.Builder<Builder> {
        @Override
        public GravesLSTM build() {
            return new GravesLSTM(this);
        }
    }
}

