/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph.vertex.impl;

import java.util.Arrays;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Computation-graph vertex that wraps a single {@link Layer}.
 *
 * <p>The vertex accepts exactly one input array (input number 0). If an {@link InputPreProcessor}
 * is configured, it is applied to that input before the input is handed to the layer. Its
 * {@code backprop} is applied to the epsilon returned by the layer during the backward pass, and
 * its {@code feedForwardMaskArray} is applied to the incoming mask.
 */
public class LayerVertex extends BaseGraphVertex {
    /**
     * Name of the workspace that is checked (and, when active but not current, borrowed via
     * {@code notifyScopeBorrowed()}) while running the input preprocessor.
     */
    private static final String LOOP_EXTERNAL_WORKSPACE = "LOOP_EXTERNAL";

    private Layer layer;
    private final InputPreProcessor layerPreProcessor;
    /** True once the (preprocessed) current input has been pushed into {@link #layer}. */
    private boolean setLayerInput;

    /**
     * Creates a vertex without input/output vertex indices.
     *
     * <p>Note: the internal input array is sized from {@code inputVertices}. A vertex built this
     * way therefore starts with zero input slots.
     */
    public LayerVertex(ComputationGraph graph, String name, int vertexIndex, Layer layer, InputPreProcessor layerPreProcessor, boolean outputVertex) {
        this(graph, name, vertexIndex, null, null, layer, layerPreProcessor, outputVertex);
    }

    /**
     * Creates a vertex wrapping {@code layer}.
     *
     * @param graph             owning graph (used for batch size and TBPTT configuration)
     * @param name              vertex name
     * @param vertexIndex       index of this vertex in the graph
     * @param inputVertices     vertices feeding this one; may be {@code null}
     * @param outputVertices    vertices fed by this one; may be {@code null}
     * @param layer             the wrapped layer
     * @param layerPreProcessor optional preprocessor applied to the input; may be {@code null}
     * @param outputVertex      whether this vertex is a network output
     */
    public LayerVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, Layer layer, InputPreProcessor layerPreProcessor, boolean outputVertex) {
        super(graph, name, vertexIndex, inputVertices, outputVertices);
        this.graph = graph;
        this.vertexName = name;
        this.vertexIndex = vertexIndex;
        this.inputVertices = inputVertices;
        this.outputVertices = outputVertices;
        this.layer = layer;
        this.layerPreProcessor = layerPreProcessor;
        this.outputVertex = outputVertex;
        this.inputs = new INDArray[inputVertices != null ? inputVertices.length : 0];
    }

    /** Always {@code true}: this vertex type wraps a layer. */
    @Override
    public boolean hasLayer() {
        return true;
    }

    /**
     * Wraps the current layer in a {@link FrozenLayer} and copies this vertex's name into the
     * layer configuration. Does nothing if the layer is already frozen.
     */
    @Override
    public void setLayerAsFrozen() {
        if (this.layer instanceof FrozenLayer) {
            return;
        }
        this.layer = new FrozenLayer(this.layer);
        this.layer.conf().getLayer().setLayerName(this.vertexName);
    }

    /** True if flagged as an output vertex at construction, or if the layer is a {@link BaseOutputLayer}. */
    @Override
    public boolean isOutputVertex() {
        return this.outputVertex || this.layer instanceof BaseOutputLayer;
    }

    @Override
    public Layer getLayer() {
        return this.layer;
    }

    /**
     * Runs the wrapped layer's forward pass on its currently set input.
     *
     * @param training passed through to {@link Layer#activate(boolean)}
     * @return the layer activations
     * @throws IllegalStateException if {@link #canDoForward()} is false
     */
    @Override
    public INDArray doForward(boolean training) {
        if (!this.canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: all inputs not set");
        }
        return this.layer.activate(training);
    }

    /**
     * Applies the optional preprocessor to {@code inputs[0]}, sets the result as the layer input,
     * and marks the layer input as set.
     *
     * <p>If the {@code LOOP_EXTERNAL} workspace exists and is active but is not the current
     * workspace, preprocessing runs inside a borrowed scope of that workspace.
     */
    protected void applyPreprocessorAndSetInput() {
        INDArray currInput = this.inputs[0];
        if (this.layerPreProcessor != null) {
            // The workspace is only looked up after the existence check, matching the original order of calls.
            if (Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(LOOP_EXTERNAL_WORKSPACE)
                    && Nd4j.getMemoryManager().getCurrentWorkspace()
                            != Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(LOOP_EXTERNAL_WORKSPACE)) {
                // NOTE(review): presumably borrowed so the preprocessed array is allocated in the
                // external workspace rather than the current one - confirm against ND4J workspace docs.
                try (MemoryWorkspace wsB = Nd4j.getWorkspaceManager()
                        .getWorkspaceForCurrentThread(LOOP_EXTERNAL_WORKSPACE)
                        .notifyScopeBorrowed()) {
                    currInput = this.layerPreProcessor.preProcess(currInput, this.graph.batchSize());
                }
            } else {
                currInput = this.layerPreProcessor.preProcess(currInput, this.graph.batchSize());
            }
        }
        this.layer.setInput(currInput);
        this.setLayerInput = true;
    }

    /**
     * Runs the backward pass through the wrapped layer.
     *
     * <p>Uses truncated BPTT when {@code tbptt} is true and the layer is a {@link RecurrentLayer};
     * otherwise uses {@link Layer#backpropGradient}. If a preprocessor is configured, its
     * {@code backprop} is applied to the returned epsilon.
     *
     * @param tbptt whether to use truncated backprop through time (recurrent layers only)
     * @return the layer gradient and a single-element array holding the epsilon for the input
     * @throws IllegalStateException if {@link #canDoBackward()} is false
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
        if (!this.canDoBackward()) {
            // Length check guards vertices built without input indices (0-length inputs array).
            if (this.inputs == null || this.inputs.length == 0 || this.inputs[0] == null) {
                throw new IllegalStateException("Cannot do backward pass: inputs not set. Layer " + this.vertexName + " (idx " + this.vertexIndex + ") numInputs " + this.getNumInputArrays());
            }
            throw new IllegalStateException("Cannot do backward pass: all epsilons not set. Layer " + this.vertexName + " (idx " + this.vertexIndex + ") numInputs " + this.getNumInputArrays() + "; numOutputs " + this.getNumOutputConnections());
        }
        if (!this.setLayerInput) {
            this.applyPreprocessorAndSetInput();
        }

        Pair<Gradient, INDArray> pair;
        if (tbptt && this.layer instanceof RecurrentLayer) {
            pair = ((RecurrentLayer) this.layer).tbpttBackpropGradient(this.epsilon, this.graph.getConfiguration().getTbpttBackLength());
        } else {
            pair = this.layer.backpropGradient(this.epsilon);
        }

        INDArray eps = pair.getSecond();
        if (this.layerPreProcessor != null) {
            eps = this.layerPreProcessor.backprop(eps, this.graph.batchSize());
        }
        return new Pair<>(pair.getFirst(), new INDArray[] {eps});
    }

    /**
     * Sets the (only) input of this vertex and immediately pushes the preprocessed input into the
     * layer.
     *
     * @throws IllegalArgumentException if {@code inputNumber} is not 0
     */
    @Override
    public void setInput(int inputNumber, INDArray input) {
        // Reject negatives too; previously they surfaced as ArrayIndexOutOfBoundsException.
        if (inputNumber != 0) {
            throw new IllegalArgumentException("Invalid input number: LayerVertex instances have only 1 input (got inputNumber = " + inputNumber + ")");
        }
        this.inputs[inputNumber] = input;
        this.setLayerInput = false;
        this.applyPreprocessorAndSetInput();
    }

    /** Delegates to the wrapped layer. */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        this.layer.setBackpropGradientsViewArray(backpropGradientsViewArray);
    }

    /**
     * Propagates the input mask through the optional preprocessor and then the wrapped layer.
     *
     * <p>Only {@code maskArrays[0]} is used. The caller's array is not modified.
     *
     * @return {@code (null, currentMaskState)} if {@code maskArrays} is null or empty; otherwise the
     *     layer's result. If the preprocessor returns {@code null}, the layer receives a null mask
     *     and a null mask state.
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        if (maskArrays == null || maskArrays.length == 0) {
            return new Pair<>(null, currentMaskState);
        }
        INDArray mask = maskArrays[0];
        MaskState state = currentMaskState;
        if (this.layerPreProcessor != null) {
            Pair<INDArray, MaskState> preprocessed = this.layerPreProcessor.feedForwardMaskArray(mask, state, minibatchSize);
            if (preprocessed == null) {
                mask = null;
                state = null;
            } else {
                mask = preprocessed.getFirst();
                state = preprocessed.getSecond();
            }
        }
        return this.layer.feedForwardMaskArray(mask, state, minibatchSize);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("LayerVertex(id=").append(this.vertexIndex).append(",name=\"").append(this.vertexName).append("\",inputs=").append(Arrays.toString(this.inputVertices)).append(",outputs=").append(Arrays.toString(this.outputVertices)).append(")");
        return sb.toString();
    }

    /**
     * Determines whether the backward pass can run.
     *
     * <p>Non-output vertex: always true for a {@link FrozenLayer}; otherwise defers to the base class.
     * Output vertex: requires every input to be set, plus either an {@link IOutputLayer} or a
     * non-null epsilon.
     */
    @Override
    public boolean canDoBackward() {
        if (!this.isOutputVertex()) {
            if (this.getLayer() instanceof FrozenLayer) {
                return true;
            }
            return super.canDoBackward();
        }
        for (INDArray input : this.inputs) {
            if (input == null) {
                return false;
            }
        }
        return this.layer instanceof IOutputLayer || this.epsilon != null;
    }

    /**
     * Computes the score of the wrapped output layer, setting its input first if needed.
     *
     * @throws UnsupportedOperationException if the layer is not an {@link IOutputLayer}
     */
    public double computeScore(double l1, double l2, boolean training) {
        return this.outputLayerWithInput().computeScore(l1, l2, training);
    }

    /**
     * Computes per-example scores of the wrapped output layer, setting its input first if needed.
     *
     * @throws UnsupportedOperationException if the layer is not an {@link IOutputLayer}
     */
    public INDArray computeScoreForExamples(double l1, double l2) {
        return this.outputLayerWithInput().computeScoreForExamples(l1, l2);
    }

    /** Checks the layer is an output layer, ensures its input is set, and returns it. */
    private IOutputLayer outputLayerWithInput() {
        if (!(this.layer instanceof IOutputLayer)) {
            throw new UnsupportedOperationException("Cannot compute score: layer is not an output layer (layer class: " + this.layer.getClass().getSimpleName() + ")");
        }
        if (!this.setLayerInput) {
            this.applyPreprocessorAndSetInput();
        }
        return (IOutputLayer) this.layer;
    }

    /** Delegates to the wrapped layer. */
    @Override
    public void migrateInput() {
        this.layer.migrateInput();
    }

    public InputPreProcessor getLayerPreProcessor() {
        return this.layerPreProcessor;
    }

    public boolean isSetLayerInput() {
        return this.setLayerInput;
    }

    public void setLayer(Layer layer) {
        this.layer = layer;
    }

    public void setSetLayerInput(boolean setLayerInput) {
        this.setLayerInput = setLayerInput;
    }

    /** Equality over base-class state, the layer, the preprocessor and the input-set flag. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LayerVertex)) {
            return false;
        }
        LayerVertex other = (LayerVertex) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        Layer thisLayer = this.getLayer();
        Layer otherLayer = other.getLayer();
        if (thisLayer == null ? otherLayer != null : !thisLayer.equals(otherLayer)) {
            return false;
        }
        InputPreProcessor thisPreProcessor = this.getLayerPreProcessor();
        InputPreProcessor otherPreProcessor = other.getLayerPreProcessor();
        if (thisPreProcessor == null ? otherPreProcessor != null : !thisPreProcessor.equals(otherPreProcessor)) {
            return false;
        }
        return this.isSetLayerInput() == other.isSetLayerInput();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof LayerVertex;
    }

    /** Hash consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        Layer currentLayer = this.getLayer();
        result = result * prime + (currentLayer == null ? 43 : currentLayer.hashCode());
        InputPreProcessor preProcessor = this.getLayerPreProcessor();
        result = result * prime + (preProcessor == null ? 43 : preProcessor.hashCode());
        result = result * prime + (this.isSetLayerInput() ? 79 : 97);
        return result;
    }
}

