/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph.vertex.impl;

import java.util.Arrays;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Graph vertex that combines all of its inputs element-wise using a single {@link Op}.
 *
 * <p>Inputs whose shapes differ are broadcast against each other; the output shape is the
 * broadcast of all input shapes. During backprop, the gradient for any input that was broadcast
 * is summed (keeping dimensions) over the broadcast dimensions so that it matches that input's
 * original shape.
 *
 * <p>With a single input the vertex is an identity regardless of {@code op}: the forward pass
 * returns a copy of the input and the backward pass returns a copy of the incoming epsilon.
 */
public class ElementWiseVertex extends BaseGraphVertex {
    /** The element-wise operation applied to the inputs. */
    private Op op;
    /** Number of inputs seen by the most recent forward pass; sizes the backward-pass output. */
    private int nInForwardPass;

    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, Op op, DataType dataType) {
        this(graph, name, vertexIndex, null, null, op, dataType);
    }

    public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, Op op, DataType dataType) {
        super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
        this.op = op;
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Combines the current inputs element-wise according to {@code op}.
     *
     * @param training     whether this is a training pass (not used by this vertex)
     * @param workspaceMgr workspace manager used to allocate the output activations
     * @return the combined activations, whose shape is the broadcast of all input shapes
     * @throws IllegalStateException         if the inputs have not been set
     * @throws IllegalArgumentException      if {@code op} is {@code Subtract} and there are not
     *                                       exactly two inputs
     * @throws UnsupportedOperationException if {@code op} is not one of the known operations
     */
    @Override
    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: inputs not set");
        }
        nInForwardPass = inputs.length;
        if (inputs.length == 1) {
            return workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]);
        }

        // Broadcasting is needed as soon as any input's shape differs from the first input's.
        boolean isBc = false;
        for (int i = 1; i < inputs.length && !isBc; i++) {
            isBc = !inputs[0].equalShapes(inputs[i]);
        }

        long[] outShape = inputs[0].shape();
        if (isBc) {
            for (int i = 1; i < inputs.length; i++) {
                outShape = Shape.broadcastOutputShape(outShape, inputs[i].shape());
            }
        }

        switch (op) {
            case Add: {
                INDArray sum = firstInputAsOutput(outShape, isBc, workspaceMgr);
                for (int i = 1; i < inputs.length; i++) {
                    sum.addi(inputs[i].castTo(dataType));
                }
                return sum;
            }
            case Average: {
                INDArray average = firstInputAsOutput(outShape, isBc, workspaceMgr);
                for (int i = 1; i < inputs.length; i++) {
                    average.addi(inputs[i].castTo(dataType));
                }
                return average.divi(inputs.length);
            }
            case Subtract: {
                if (inputs.length != 2) {
                    throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
                }
                INDArray difference = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), outShape);
                return Nd4j.exec(new SubOp(inputs, new INDArray[]{difference}))[0];
            }
            case Product: {
                INDArray product = firstInputAsOutput(outShape, isBc, workspaceMgr);
                for (int i = 1; i < inputs.length; i++) {
                    product.muli(inputs[i].castTo(dataType));
                }
                return product;
            }
            case Max: {
                if (!isBc) {
                    INDArray max = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), inputs[0].shape(), inputs[0].ordering());
                    DynamicCustomOp mergeMax = DynamicCustomOp.builder("mergemax")
                            .addInputs(inputs)
                            .addOutputs(new INDArray[]{max})
                            .callInplace(false)
                            .build();
                    Nd4j.getExecutioner().exec(mergeMax);
                    return max;
                }
                // "mergemax" is only used for equal shapes (presumably it does not broadcast),
                // so fold pairwise with Transforms.max; the first call copies, later ones are in place.
                INDArray max = Transforms.max(inputs[0], inputs[1], true);
                for (int i = 2; i < inputs.length; i++) {
                    max = Transforms.max(max, inputs[i], false);
                }
                return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, max);
            }
            default:
                throw new UnsupportedOperationException("Unknown op: " + op);
        }
    }

    /**
     * Allocates an uninitialised activations array of {@code outShape} (in this vertex's data type)
     * and fills it with the first input, broadcasting when the first input's shape differs.
     */
    private INDArray firstInputAsOutput(long[] outShape, boolean isBc, LayerWorkspaceMgr workspaceMgr) {
        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, outShape);
        if (isBc && !Arrays.equals(outShape, inputs[0].shape())) {
            Nd4j.exec(new BroadcastTo(inputs[0], outShape, out));
        } else {
            out.assign(inputs[0]);
        }
        return out;
    }

    /**
     * Computes the gradient for every input of the most recent forward pass.
     *
     * <p>Gradients of inputs that were broadcast in the forward pass are summed over the broadcast
     * dimensions so that each returned array has the same shape as its input.
     *
     * @param tbptt        truncated-BPTT flag (not used by this vertex)
     * @param workspaceMgr workspace manager used to allocate the gradients
     * @return a pair with no parameter gradient and one epsilon array per input
     * @throws IllegalStateException         if epsilon has not been set
     * @throws UnsupportedOperationException if {@code op} is not one of the known operations
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
        if (!canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set");
        }
        if (nInForwardPass == 1) {
            return new Pair<>(null, new INDArray[]{workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon)});
        }

        boolean broadcastCase = false;
        for (int i = 1; i < nInForwardPass && !broadcastCase; i++) {
            broadcastCase = !inputs[0].equalShapes(inputs[i]);
        }

        switch (op) {
            case Add: {
                // d(sum)/d(input_i) = 1, so each input receives epsilon (reduced if broadcast).
                INDArray[] out = new INDArray[nInForwardPass];
                for (int i = 0; i < nInForwardPass; i++) {
                    out[i] = broadcastCase
                            ? gradientForInput(epsilon, inputs[i], workspaceMgr)
                            : workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
                }
                return new Pair<>(null, out);
            }
            case Average: {
                // d(mean)/d(input_i) = 1/n.
                INDArray[] outAverage = new INDArray[nInForwardPass];
                try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
                    for (int i = 0; i < nInForwardPass; i++) {
                        if (inputs[i].equalShapes(epsilon)) {
                            outAverage[i] = epsilon.div(nInForwardPass);
                        } else {
                            int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
                            outAverage[i] = epsilon.div(nInForwardPass).sum(true, bcDim);
                        }
                    }
                }
                return new Pair<>(null, outAverage);
            }
            case Subtract: {
                // d(a - b)/da = 1, d(a - b)/db = -1. Either operand (or both) may have been broadcast,
                // so each gradient is reduced to its own input's shape independently.
                INDArray[] out = new INDArray[2];
                if (!broadcastCase) {
                    out[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
                    out[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
                } else {
                    out[0] = gradientForInput(epsilon, inputs[0], workspaceMgr);
                    out[1] = gradientForInput(epsilon, inputs[1], workspaceMgr).negi();
                }
                return new Pair<>(null, out);
            }
            case Product: {
                // d(prod)/d(input_i) = epsilon * product of all other inputs (broadcast to epsilon's shape).
                INDArray[] inBc = inputs;
                if (broadcastCase) {
                    inBc = new INDArray[inputs.length];
                    for (int i = 0; i < inputs.length; i++) {
                        inBc[i] = broadcastToShapeOf(inputs[i], epsilon);
                    }
                }
                INDArray[] out = new INDArray[nInForwardPass];
                for (int i = 0; i < nInForwardPass; i++) {
                    INDArray grad = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
                    for (int j = 0; j < nInForwardPass; j++) {
                        if (j != i) {
                            grad.muli(inBc[j]);
                        }
                    }
                    out[i] = inputs[i].equalShapes(epsilon)
                            ? grad
                            : sumOverBroadcastDims(grad, inputs[i], workspaceMgr);
                }
                return new Pair<>(null, out);
            }
            case Max: {
                // Epsilon is routed only to the positions where an input matches the index reported
                // by "mergemaxindex" (presumably the index of the input holding the maximum).
                INDArray[] outMax = new INDArray[nInForwardPass];
                INDArray maxIndices = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, DataType.INT, epsilon.shape(), epsilon.ordering());
                INDArray[] bcIn = inputs;
                if (broadcastCase) {
                    bcIn = new INDArray[inputs.length];
                    for (int i = 0; i < inputs.length; i++) {
                        bcIn[i] = broadcastToShapeOf(inputs[i], epsilon);
                    }
                }
                DynamicCustomOp mergeMaxIndex = DynamicCustomOp.builder("mergemaxindex")
                        .addInputs(bcIn)
                        .addOutputs(new INDArray[]{maxIndices})
                        .callInplace(false)
                        .build();
                Nd4j.getExecutioner().exec(mergeMaxIndex);
                for (int i = 0; i < nInForwardPass; i++) {
                    INDArray isMax = workspaceMgr.create(ArrayType.BP_WORKING_MEM, DataType.BOOL, maxIndices.shape());
                    Nd4j.getExecutioner().exec(new MatchConditionTransform(maxIndices, isMax, Conditions.equals(i)));
                    if (broadcastCase && !epsilon.equalShapes(inputs[i])) {
                        INDArray masked = isMax.castTo(epsilon.dataType()).mul(epsilon);
                        outMax[i] = sumOverBroadcastDims(masked, inputs[i], workspaceMgr);
                    } else {
                        outMax[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, isMax.castTo(epsilon.dataType()).muli(epsilon));
                    }
                }
                return new Pair<>(null, outMax);
            }
            default:
                throw new UnsupportedOperationException("Unknown op: " + op);
        }
    }

    /**
     * Returns {@code grad} shaped for {@code input}. If the shapes already match, the result is a
     * copy in the ACTIVATION_GRAD workspace. Otherwise it is {@code grad} summed over the
     * dimensions along which {@code input} was broadcast.
     */
    private static INDArray gradientForInput(INDArray grad, INDArray input, LayerWorkspaceMgr workspaceMgr) {
        if (input.equalShapes(grad)) {
            return workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, grad);
        }
        return sumOverBroadcastDims(grad, input, workspaceMgr);
    }

    /**
     * Sums {@code grad} (keeping dimensions) over the dimensions along which {@code input} was
     * broadcast. The sum is computed while the ACTIVATION_GRAD workspace scope is borrowed.
     */
    private static INDArray sumOverBroadcastDims(INDArray grad, INDArray input, LayerWorkspaceMgr workspaceMgr) {
        int[] bcDim = Shape.getBroadcastDimensions(input.shape(), grad.shape());
        try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
            return grad.sum(true, bcDim);
        }
    }

    /**
     * Returns {@code in} unchanged if it already has {@code target}'s shape. Otherwise it returns a
     * new array like {@code target} holding {@code in} broadcast to that shape.
     */
    private static INDArray broadcastToShapeOf(INDArray in, INDArray target) {
        if (in.equalShapes(target)) {
            return in;
        }
        INDArray out = target.ulike();
        Nd4j.exec(new BroadcastTo(in, target.shape(), out));
        return out;
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Combines the input masks with an element-wise OR.
     *
     * <p>If the mask array list is null or empty, or if any individual mask is null, no output mask
     * is produced. A single mask is passed through unchanged. Otherwise the OR of all masks is
     * returned, cast to the default floating-point type.
     *
     * @return the combined mask (possibly null) and the unchanged {@code currentMaskState}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        if (maskArrays == null || maskArrays.length == 0) {
            return new Pair<>(null, currentMaskState);
        }
        for (INDArray arr : maskArrays) {
            if (arr == null) {
                return new Pair<>(null, currentMaskState);
            }
        }
        if (maskArrays.length == 1) {
            return new Pair<>(maskArrays[0], currentMaskState);
        }
        INDArray ret = Nd4j.createUninitialized(DataType.BOOL, maskArrays[0].shape());
        Nd4j.getExecutioner().exec(new Or(maskArrays[0].castTo(DataType.BOOL), maskArrays[1].castTo(DataType.BOOL), ret));
        for (int i = 2; i < maskArrays.length; i++) {
            Nd4j.getExecutioner().exec(new Or(maskArrays[i].castTo(DataType.BOOL), ret, ret));
        }
        return new Pair<>(ret.castTo(Nd4j.defaultFloatingPointType()), currentMaskState);
    }

    @Override
    public String toString() {
        return "ElementWiseVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\",op=" + op + ")";
    }

    /** Element-wise operations supported by {@link ElementWiseVertex}. */
    public static enum Op {
        /** Element-wise sum of all inputs. */
        Add,
        /** First input minus second input; exactly two inputs are required. */
        Subtract,
        /** Element-wise product of all inputs. */
        Product,
        /** Element-wise sum of all inputs divided by the number of inputs. */
        Average,
        /** Element-wise maximum over all inputs. */
        Max;

    }
}

