/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.graph;

import java.beans.ConstructorProperties;
import java.util.Arrays;
import java.util.Objects;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.convolution.KernelValidationUtil;
import org.deeplearning4j.nn.layers.factory.LayerFactories;

/**
 * Configuration-side graph vertex that pairs a single layer configuration with an optional
 * input pre-processor.
 *
 * <p>Both fields may be {@code null} (the no-arg constructor leaves them unset), so
 * {@link #equals(Object)}, {@link #hashCode()} and {@link #clone()} are written to tolerate
 * a missing layer configuration or pre-processor.
 */
public class LayerVertex
extends GraphVertex {
    private NeuralNetConfiguration layerConf;
    private InputPreProcessor preProcessor;

    /**
     * Creates a copy of this vertex whose layer configuration and pre-processor are themselves
     * cloned when present.
     *
     * @return a new {@code LayerVertex}; unset fields stay {@code null} in the copy
     */
    @Override
    public GraphVertex clone() {
        NeuralNetConfiguration confCopy = this.layerConf != null ? this.layerConf.clone() : null;
        InputPreProcessor preProcessorCopy = this.preProcessor != null ? this.preProcessor.clone() : null;
        return new LayerVertex(confCopy, preProcessorCopy);
    }

    /**
     * Two layer vertices are equal when both their layer configurations and their
     * pre-processors are equal (or both absent).
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof LayerVertex)) {
            return false;
        }
        LayerVertex other = (LayerVertex)o;
        // Objects.equals is null-safe, so vertices built via the no-arg constructor compare cleanly.
        return Objects.equals(this.layerConf, other.layerConf)
                && Objects.equals(this.preProcessor, other.preProcessor);
    }

    /**
     * XOR of the two field hashes; a {@code null} field contributes 0.
     */
    @Override
    public int hashCode() {
        return Objects.hashCode(this.layerConf) ^ Objects.hashCode(this.preProcessor);
    }

    /**
     * Builds the runtime vertex for this configuration: the layer is created through the
     * factory registered for {@code layerConf} and wrapped, together with the pre-processor,
     * in the runtime {@code LayerVertex} implementation.
     *
     * @param graph the computation graph the runtime vertex belongs to
     * @param name  the vertex name
     * @param idx   the vertex index (also passed to the layer factory)
     * @return the runtime vertex
     */
    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx) {
        return new org.deeplearning4j.nn.graph.vertex.impl.LayerVertex(graph, name, idx, (org.deeplearning4j.nn.api.Layer)LayerFactories.getFactory(this.layerConf).create(this.layerConf, null, idx), this.preProcessor);
    }

    /**
     * Computes the type of activations this vertex emits for the given input type.
     *
     * <ul>
     *   <li>Convolution / subsampling layers: convolutional output whose spatial size is derived
     *       from kernel, stride and padding.</li>
     *   <li>Recurrent layers: recurrent output of size {@code nOut}.</li>
     *   <li>Other feed-forward layers: feed-forward output of size {@code nOut}.</li>
     *   <li>Anything else: the input type is passed through unchanged.</li>
     * </ul>
     *
     * @param vertexInputs the input types; exactly one is expected
     * @return the output type of the wrapped layer
     * @throws InvalidInputTypeException if the number of inputs is not one, or if a
     *         convolution/subsampling layer is not given convolutional input
     */
    @Override
    public InputType getOutputType(InputType ... vertexInputs) throws InvalidInputTypeException {
        if (vertexInputs == null || vertexInputs.length != 1) {
            throw new InvalidInputTypeException("LayerVertex expects exactly one input. Got: " + Arrays.toString(vertexInputs));
        }
        Layer layer = this.layerConf.getLayer();
        if (layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer) {
            return this.convolutionalOutputType(layer, vertexInputs[0]);
        }
        if (layer instanceof BaseRecurrentLayer) {
            return InputType.recurrent(((BaseRecurrentLayer)layer).getNOut());
        }
        if (layer instanceof FeedForwardLayer) {
            return InputType.feedForward(((FeedForwardLayer)layer).getNOut());
        }
        return vertexInputs[0];
    }

    /**
     * Computes the convolutional output type of a convolution or subsampling layer.
     *
     * @param layer a {@link ConvolutionLayer} or {@link SubsamplingLayer}
     * @param input the single input type handed to the vertex
     */
    private InputType convolutionalOutputType(Layer layer, InputType input) throws InvalidInputTypeException {
        InputType.InputTypeConvolutional cnnInput = this.resolveConvolutionalInput(input);

        int channelsOut;
        int[] kernel;
        int[] stride;
        int[] padding;
        if (layer instanceof ConvolutionLayer) {
            ConvolutionLayer conv = (ConvolutionLayer)layer;
            channelsOut = conv.getNOut();
            kernel = conv.getKernelSize();
            stride = conv.getStride();
            padding = conv.getPadding();
        } else {
            SubsamplingLayer pool = (SubsamplingLayer)layer;
            // Subsampling does not change the channel count.
            channelsOut = cnnInput.getDepth();
            kernel = pool.getKernelSize();
            stride = pool.getStride();
            padding = pool.getPadding();
        }

        int inHeight = cnnInput.getHeight();
        int inWidth = cnnInput.getWidth();
        KernelValidationUtil.validateShapes(inHeight, inWidth, kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1]);
        int outWidth = (inWidth - kernel[1] + 2 * padding[1]) / stride[1] + 1;
        int outHeight = (inHeight - kernel[0] + 2 * padding[0]) / stride[0] + 1;
        return InputType.convolutional(outHeight, outWidth, channelsOut);
    }

    /**
     * Determines the convolutional type the layer actually sees after the pre-processor.
     *
     * <p>For the feed-forward-to-CNN and RNN-to-CNN pre-processors the shape is taken from the
     * pre-processor itself; otherwise the vertex input must already be convolutional.
     */
    private InputType.InputTypeConvolutional resolveConvolutionalInput(InputType input) throws InvalidInputTypeException {
        if (this.preProcessor instanceof FeedForwardToCnnPreProcessor) {
            FeedForwardToCnnPreProcessor ffcnn = (FeedForwardToCnnPreProcessor)this.preProcessor;
            return (InputType.InputTypeConvolutional)InputType.convolutional(ffcnn.getInputHeight(), ffcnn.getInputWidth(), ffcnn.getNumChannels());
        }
        if (this.preProcessor instanceof RnnToCnnPreProcessor) {
            RnnToCnnPreProcessor rnncnn = (RnnToCnnPreProcessor)this.preProcessor;
            return (InputType.InputTypeConvolutional)InputType.convolutional(rnncnn.getInputHeight(), rnncnn.getInputWidth(), rnncnn.getNumChannels());
        }
        // No shape-providing pre-processor: report a bad input as the declared exception type
        // instead of letting a blind cast fail with ClassCastException.
        if (!(input instanceof InputType.InputTypeConvolutional)) {
            throw new InvalidInputTypeException("LayerVertex with a convolution or subsampling layer expects convolutional input. Got: " + input);
        }
        return (InputType.InputTypeConvolutional)input;
    }

    /**
     * @param layerConf    configuration of the wrapped layer
     * @param preProcessor pre-processor applied to the input, or {@code null} for none
     */
    @ConstructorProperties(value={"layerConf", "preProcessor"})
    public LayerVertex(NeuralNetConfiguration layerConf, InputPreProcessor preProcessor) {
        this.layerConf = layerConf;
        this.preProcessor = preProcessor;
    }

    /**
     * Creates an empty vertex; both fields are {@code null} until set.
     */
    public LayerVertex() {
    }

    /** @return the layer configuration, possibly {@code null} */
    public NeuralNetConfiguration getLayerConf() {
        return this.layerConf;
    }

    /** @return the input pre-processor, or {@code null} if none is set */
    public InputPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /** @param layerConf the layer configuration to use */
    public void setLayerConf(NeuralNetConfiguration layerConf) {
        this.layerConf = layerConf;
    }

    /** @param preProcessor the input pre-processor to use, or {@code null} for none */
    public void setPreProcessor(InputPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public String toString() {
        return "LayerVertex(layerConf=" + this.getLayerConf() + ", preProcessor=" + this.getPreProcessor() + ")";
    }
}

