/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;

/**
 * Base configuration for layers whose input and output are feed-forward activations.
 *
 * <p>Holds the layer's input size ({@code nIn}) and output size ({@code nOut}).
 * {@link InputType.Type#FF} and {@link InputType.Type#CNNFlat} inputs are accepted directly.
 * RNN, CNN and CNN3D inputs are adapted through the preprocessor returned by
 * {@link #getPreProcessorForInputType(InputType)}.
 */
public abstract class FeedForwardLayer
extends BaseLayer {
    /** Number of inputs; a value {@code <= 0} is treated as "not yet set" by {@link #setNIn(InputType, boolean)}. */
    protected long nIn;
    /** Number of outputs produced by this layer. */
    protected long nOut;

    /**
     * Creates the layer configuration from a builder.
     *
     * @param builder builder supplying {@code nIn}, {@code nOut} and the base-layer settings
     */
    public FeedForwardLayer(Builder builder) {
        super(builder);
        this.nIn = builder.nIn;
        this.nOut = builder.nOut;
    }

    /**
     * Returns whether {@code inputType} can be consumed directly by a feed-forward layer.
     * Only non-null FF or CNNFlat inputs qualify.
     */
    private static boolean isFeedForwardCompatible(InputType inputType) {
        return inputType != null
                && (inputType.getType() == InputType.Type.FF || inputType.getType() == InputType.Type.CNNFlat);
    }

    /**
     * Computes this layer's output type.
     *
     * @param layerIndex index of this layer, used only in the error message
     * @param inputType  type of the activations fed into this layer; must be FF or CNNFlat
     * @return a feed-forward type of size {@code nOut}
     * @throws IllegalStateException if {@code inputType} is null or not FF/CNNFlat
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (!isFeedForwardCompatible(inputType)) {
            throw new IllegalStateException("Invalid input type (layer index = " + layerIndex + ", layer name=\"" + this.getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
        }
        return InputType.feedForward(this.nOut);
    }

    /**
     * Infers {@code nIn} from the incoming activation type.
     *
     * <p>The value is only written when {@code nIn} is unset ({@code <= 0}) or {@code override} is true.
     * For FF input the feed-forward size is used. For CNNFlat input the flattened size is used.
     *
     * @param inputType type of the activations fed into this layer; must be FF or CNNFlat
     * @param override  if true, replace an already-set {@code nIn}
     * @throws IllegalStateException if {@code inputType} is null or not FF/CNNFlat
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (!isFeedForwardCompatible(inputType)) {
            throw new IllegalStateException("Invalid input type (layer name=\"" + this.getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
        }
        if (this.nIn > 0L && !override) {
            // Keep an explicitly configured nIn unless the caller asked to override it.
            return;
        }
        if (inputType.getType() == InputType.Type.FF) {
            this.nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
        } else {
            // Only CNNFlat remains after the validation above.
            this.nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getFlattenedSize();
        }
    }

    /**
     * Returns the preprocessor needed to turn {@code inputType} into feed-forward input.
     *
     * @param inputType type of the activations fed into this layer
     * @return {@code null} for FF/CNNFlat input (no conversion needed); otherwise an RNN, CNN or CNN3D
     *         to feed-forward preprocessor
     * @throws IllegalStateException if {@code inputType} is null
     * @throws RuntimeException      if the input type is not one of the handled kinds
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input for layer (layer name = \"" + this.getLayerName() + "\"): input type is null");
        }
        switch (inputType.getType()) {
            case FF:
            case CNNFlat:
                // Already in a form this layer accepts directly.
                return null;
            case RNN:
                return new RnnToFeedForwardPreProcessor();
            case CNN: {
                InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
                return new CnnToFeedForwardPreProcessor(c.getHeight(), c.getWidth(), c.getChannels());
            }
            case CNN3D: {
                InputType.InputTypeConvolutional3D c3d = (InputType.InputTypeConvolutional3D) inputType;
                // NOTE(review): the trailing `true` is presumably a data-layout flag; confirm against
                // Cnn3DToFeedForwardPreProcessor's constructor.
                return new Cnn3DToFeedForwardPreProcessor(c3d.getDepth(), c3d.getHeight(), c3d.getWidth(), c3d.getChannels(), true);
            }
            default:
                throw new RuntimeException("Unknown input type: " + inputType);
        }
    }

    /** Feed-forward layers declare no pretrain parameters; always returns {@code false}. */
    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    /** @return the configured number of inputs */
    public long getNIn() {
        return this.nIn;
    }

    /** @return the configured number of outputs */
    public long getNOut() {
        return this.nOut;
    }

    /** @param nIn number of inputs to store */
    public void setNIn(long nIn) {
        this.nIn = nIn;
    }

    /** @param nOut number of outputs to store */
    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    /** No-arg constructor; leaves {@code nIn} and {@code nOut} at their default of 0. */
    public FeedForwardLayer() {
    }

    @Override
    public String toString() {
        return "FeedForwardLayer(super=" + super.toString() + ", nIn=" + this.getNIn() + ", nOut=" + this.getNOut() + ")";
    }

    /**
     * Two layers are equal when the base-layer state matches and {@code nIn} and {@code nOut} are equal.
     * Uses the {@link #canEqual(Object)} protocol to stay symmetric across subclasses.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof FeedForwardLayer)) {
            return false;
        }
        FeedForwardLayer other = (FeedForwardLayer) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        return this.getNIn() == other.getNIn() && this.getNOut() == other.getNOut();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof FeedForwardLayer;
    }

    /** Hash consistent with {@link #equals(Object)}: combines the base hash with {@code nIn} and {@code nOut}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        // Long.hashCode(v) == (int) (v ^ (v >>> 32)), matching the previous hand-rolled fold.
        result = result * prime + Long.hashCode(this.getNIn());
        result = result * prime + Long.hashCode(this.getNOut());
        return result;
    }

    /**
     * Builder for feed-forward layer configurations.
     *
     * @param <T> concrete builder type, returned from the fluent setters
     */
    public static abstract class Builder<T extends Builder<T>>
    extends BaseLayer.Builder<T> {
        /** Number of inputs; 0 means "not set" and lets {@code setNIn(InputType, boolean)} infer it. */
        protected long nIn = 0L;
        /** Number of outputs. */
        protected long nOut = 0L;

        /** Sets the number of inputs. */
        @SuppressWarnings("unchecked")
        public T nIn(int nIn) {
            this.setNIn(nIn);
            return (T) this;
        }

        /** Sets the number of inputs. */
        @SuppressWarnings("unchecked")
        public T nIn(long nIn) {
            this.setNIn(nIn);
            return (T) this;
        }

        /** Sets the number of outputs. */
        @SuppressWarnings("unchecked")
        public T nOut(int nOut) {
            this.setNOut(nOut);
            return (T) this;
        }

        /** Sets the number of outputs; the full {@code long} value is stored without narrowing. */
        @SuppressWarnings("unchecked")
        public T nOut(long nOut) {
            // Previously cast to int here, which silently truncated values above Integer.MAX_VALUE.
            this.setNOut(nOut);
            return (T) this;
        }

        /** Alias for {@link #nOut(int)}. */
        public T units(int units) {
            return this.nOut(units);
        }

        public long getNIn() {
            return this.nIn;
        }

        public long getNOut() {
            return this.nOut;
        }

        public void setNIn(long nIn) {
            this.nIn = nIn;
        }

        public void setNOut(long nOut) {
            this.nOut = nOut;
        }
    }
}

