/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.weights.WeightInit;

public abstract class BaseRecurrentLayer
extends FeedForwardLayer {
    protected WeightInit weightInitRecurrent;
    protected Distribution distRecurrent;

    protected BaseRecurrentLayer(Builder builder) {
        super(builder);
        this.weightInitRecurrent = builder.weightInitRecurrent;
        this.distRecurrent = builder.distRecurrent;
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for RNN layer (layer index = " + layerIndex + ", layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent)inputType;
        return InputType.recurrent(this.nOut, itr.getTimeSeriesLength());
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for RNN layer (layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        if (this.nIn <= 0 || override) {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent)inputType;
            this.nIn = r.getSize();
        }
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, this.getLayerName());
    }

    public WeightInit getWeightInitRecurrent() {
        return this.weightInitRecurrent;
    }

    public Distribution getDistRecurrent() {
        return this.distRecurrent;
    }

    public void setWeightInitRecurrent(WeightInit weightInitRecurrent) {
        this.weightInitRecurrent = weightInitRecurrent;
    }

    public void setDistRecurrent(Distribution distRecurrent) {
        this.distRecurrent = distRecurrent;
    }

    public BaseRecurrentLayer() {
    }

    @Override
    public String toString() {
        return "BaseRecurrentLayer(super=" + super.toString() + ", weightInitRecurrent=" + (Object)((Object)this.getWeightInitRecurrent()) + ", distRecurrent=" + this.getDistRecurrent() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BaseRecurrentLayer)) {
            return false;
        }
        BaseRecurrentLayer other = (BaseRecurrentLayer)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        WeightInit this$weightInitRecurrent = this.getWeightInitRecurrent();
        WeightInit other$weightInitRecurrent = other.getWeightInitRecurrent();
        if (this$weightInitRecurrent == null ? other$weightInitRecurrent != null : !((Object)((Object)this$weightInitRecurrent)).equals((Object)other$weightInitRecurrent)) {
            return false;
        }
        Distribution this$distRecurrent = this.getDistRecurrent();
        Distribution other$distRecurrent = other.getDistRecurrent();
        return !(this$distRecurrent == null ? other$distRecurrent != null : !this$distRecurrent.equals(other$distRecurrent));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof BaseRecurrentLayer;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        WeightInit $weightInitRecurrent = this.getWeightInitRecurrent();
        result = result * 59 + ($weightInitRecurrent == null ? 43 : ((Object)((Object)$weightInitRecurrent)).hashCode());
        Distribution $distRecurrent = this.getDistRecurrent();
        result = result * 59 + ($distRecurrent == null ? 43 : $distRecurrent.hashCode());
        return result;
    }

    public static abstract class Builder<T extends Builder<T>>
    extends FeedForwardLayer.Builder<T> {
        protected List<LayerConstraint> recurrentConstraints;
        protected List<LayerConstraint> inputWeightConstraints;
        protected WeightInit weightInitRecurrent;
        protected Distribution distRecurrent;

        public T constrainRecurrent(LayerConstraint ... constraints) {
            this.recurrentConstraints = Arrays.asList(constraints);
            return (T)this;
        }

        public T constrainInputWeights(LayerConstraint ... constraints) {
            this.inputWeightConstraints = Arrays.asList(constraints);
            return (T)this;
        }

        public T weightInitRecurrent(WeightInit weightInit) {
            this.weightInitRecurrent = weightInit;
            return (T)this;
        }

        public T weightInitRecurrent(Distribution dist) {
            this.weightInitRecurrent = WeightInit.DISTRIBUTION;
            this.distRecurrent = dist;
            return (T)this;
        }
    }
}

