/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.recurrent;

import java.util.Collection;
import java.util.Objects;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;

public class TimeDistributed
extends BaseWrapperLayer {
    private final int timeAxis;

    public TimeDistributed(@JsonProperty(value="underlying") @NonNull Layer underlying, @JsonProperty(value="timeAxis") int timeAxis) {
        super(underlying);
        if (underlying == null) {
            throw new NullPointerException("underlying is marked @NonNull but is null");
        }
        this.timeAxis = timeAxis;
    }

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        NeuralNetConfiguration conf2 = conf.clone();
        conf2.setLayer(((TimeDistributed)conf2.getLayer()).getUnderlying());
        return new TimeDistributedLayer(this.underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType), this.timeAxis);
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer (layer #" + layerIndex + ")");
        }
        InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent)inputType;
        InputType ff = InputType.feedForward(rnn.getSize());
        InputType ffOut = this.underlying.getOutputType(layerIndex, ff);
        return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength());
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer");
        }
        InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent)inputType;
        InputType ff = InputType.feedForward(rnn.getSize());
        this.underlying.setNIn(ff, override);
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }

    public int getTimeAxis() {
        return this.timeAxis;
    }

    @Override
    public String toString() {
        return "TimeDistributed(timeAxis=" + this.getTimeAxis() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TimeDistributed)) {
            return false;
        }
        TimeDistributed other = (TimeDistributed)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        return this.getTimeAxis() == other.getTimeAxis();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof TimeDistributed;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        result = result * 59 + this.getTimeAxis();
        return result;
    }
}

