/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph.vertex.impl;

import java.util.Arrays;
import java.util.Objects;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Graph vertex that wraps a single {@link Layer}.
 *
 * <p>On the forward pass the (single) input may first be passed through an optional
 * {@link InputPreProcessor}; on the backward pass the epsilon produced by the layer is passed
 * back through the same pre-processor's {@code backprop}.
 */
public class LayerVertex
extends BaseGraphVertex {
    /** Multiplier used when combining field hash codes in {@link #hashCode()}. */
    private static final int HASH_PRIME = 59;

    private Layer layer;
    private InputPreProcessor layerPreProcessor;

    /**
     * Creates a layer vertex with no input/output vertex connections.
     *
     * @param graph the owning computation graph
     * @param name the vertex name
     * @param vertexIndex the index of this vertex in the graph
     * @param layer the wrapped layer
     * @param layerPreProcessor optional pre-processor applied to the layer input; may be {@code null}
     */
    public LayerVertex(ComputationGraph graph, String name, int vertexIndex, Layer layer, InputPreProcessor layerPreProcessor) {
        this(graph, name, vertexIndex, null, null, layer, layerPreProcessor);
    }

    /**
     * Creates a layer vertex.
     *
     * @param graph the owning computation graph
     * @param name the vertex name
     * @param vertexIndex the index of this vertex in the graph
     * @param inputVertices vertices feeding into this one; may be {@code null} (treated as none)
     * @param outputVertices vertices this one feeds into; may be {@code null} (treated as none)
     * @param layer the wrapped layer
     * @param layerPreProcessor optional pre-processor applied to the layer input; may be {@code null}
     */
    public LayerVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, Layer layer, InputPreProcessor layerPreProcessor) {
        super(graph, name, vertexIndex, inputVertices, outputVertices);
        // NOTE(review): these assignments likely duplicate what the superclass constructor does;
        // kept because BaseGraphVertex is not visible here — confirm before removing.
        this.graph = graph;
        this.vertexName = name;
        this.vertexIndex = vertexIndex;
        this.inputVertices = inputVertices;
        this.outputVertices = outputVertices;
        this.layer = layer;
        this.layerPreProcessor = layerPreProcessor;
        // One slot per connected vertex; null connection arrays mean "no connections".
        this.inputs = new INDArray[inputVertices != null ? inputVertices.length : 0];
        this.epsilons = new INDArray[outputVertices != null ? outputVertices.length : 0];
    }

    /** Always {@code true}: this vertex type wraps a layer. */
    @Override
    public boolean hasLayer() {
        return true;
    }

    /** Returns {@code true} when the wrapped layer is a {@link BaseOutputLayer}. */
    @Override
    public boolean isOutputVertex() {
        return this.layer instanceof BaseOutputLayer;
    }

    /** Returns the wrapped layer. */
    @Override
    public Layer getLayer() {
        return this.layer;
    }

    /**
     * Runs the forward pass: applies the pre-processor (if any) to the first input, then activates
     * the wrapped layer.
     *
     * @param training passed through to {@link Layer#activate}
     * @return the layer activations
     * @throws IllegalStateException if {@code canDoForward()} reports that inputs are missing
     */
    @Override
    public INDArray doForward(boolean training) {
        if (!this.canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: all inputs not set");
        }
        // Only the first input is used by a layer vertex.
        INDArray currInput = this.inputs[0];
        if (this.layerPreProcessor != null) {
            currInput = this.layerPreProcessor.preProcess(currInput, this.graph.batchSize());
        }
        return this.layer.activate(currInput, training);
    }

    /**
     * Runs the backward pass.
     *
     * <p>The epsilons received from all output vertices are summed. The result is back-propagated
     * through the layer, using truncated BPTT when {@code tbptt} is set and the layer is a
     * {@link BaseRecurrentLayer}. The resulting epsilon is then passed back through the
     * pre-processor, if one is configured.
     *
     * @param tbptt whether to use truncated backpropagation through time for recurrent layers
     * @return the layer gradient and a single-element array holding the epsilon for the input
     * @throws IllegalStateException if {@code canDoBackward()} reports that epsilons are missing
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
        if (!this.canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: all epsilons not set");
        }
        INDArray epsTotal = this.sumEpsilons();
        Pair<Gradient, INDArray> pair;
        if (tbptt && this.layer instanceof BaseRecurrentLayer) {
            pair = ((BaseRecurrentLayer) this.layer).tbpttBackpropGradient(epsTotal, this.graph.getConfiguration().getTbpttBackLength());
        } else {
            pair = this.layer.backpropGradient(epsTotal);
        }
        if (this.layerPreProcessor != null) {
            pair.setSecond(this.layerPreProcessor.backprop(pair.getSecond(), this.graph.batchSize()));
        }
        return new Pair<Gradient, INDArray[]>(pair.getFirst(), new INDArray[]{pair.getSecond()});
    }

    /**
     * Sums the epsilons received from output vertices.
     *
     * <p>Returns {@code null} when there are none. Returns the single epsilon itself (not a copy)
     * when there is exactly one. Otherwise returns a fresh array, so the stored epsilons are not
     * mutated by the in-place additions.
     */
    private INDArray sumEpsilons() {
        if (this.epsilons == null || this.epsilons.length == 0) {
            return null;
        }
        if (this.epsilons.length == 1) {
            return this.epsilons[0];
        }
        INDArray total = this.epsilons[0].dup();
        for (int i = 1; i < this.epsilons.length; ++i) {
            total.addi(this.epsilons[i]);
        }
        return total;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("LayerVertex(id=").append(this.vertexIndex)
                .append(",name=\"").append(this.vertexName)
                .append("\",inputs=").append(Arrays.toString(this.inputVertices))
                .append(",outputs=").append(Arrays.toString(this.outputVertices))
                .append(")");
        return sb.toString();
    }

    /** Returns the input pre-processor, or {@code null} if none is configured. */
    public InputPreProcessor getLayerPreProcessor() {
        return this.layerPreProcessor;
    }

    /** Replaces the wrapped layer. */
    public void setLayer(Layer layer) {
        this.layer = layer;
    }

    /** Replaces the input pre-processor; {@code null} disables pre-processing. */
    public void setLayerPreProcessor(InputPreProcessor layerPreProcessor) {
        this.layerPreProcessor = layerPreProcessor;
    }

    /**
     * Two layer vertices are equal when their layers and pre-processors are equal. Superclass
     * state is not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LayerVertex)) {
            return false;
        }
        LayerVertex other = (LayerVertex) o;
        // canEqual keeps equality symmetric if a subclass narrows what it considers equal.
        return other.canEqual(this)
                && Objects.equals(this.getLayer(), other.getLayer())
                && Objects.equals(this.getLayerPreProcessor(), other.getLayerPreProcessor());
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof LayerVertex;
    }

    /** Hash over the same fields compared in {@link #equals(Object)}; null fields contribute 0. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_PRIME + Objects.hashCode(this.getLayer());
        result = result * HASH_PRIME + Objects.hashCode(this.getLayerPreProcessor());
        return result;
    }
}

