/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.controlflow;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.controlflow.IfDerivative;
import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.util.HashUtil;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * SameDiff conditional ("if") control-flow op.
 *
 * <p>The op holds three sub-graphs:
 * <ul>
 *   <li>a predicate graph ({@code predicateExecution}) that produces {@code targetBoolean};</li>
 *   <li>a true body ({@code loopBodyExecution});</li>
 *   <li>a false body ({@code falseBodyExecution}).</li>
 * </ul>
 * Selecting and running a branch is not done in this class.
 * {@link #exectedTrueOrFalse(boolean)} only records which branch was taken. In the parent graph
 * the op exposes a single placeholder output variable ({@code dummyResult}).
 */
public class If
extends DifferentialFunction
implements CustomOp {
    private static final Logger log = LoggerFactory.getLogger(If.class);
    /** Sub-graph for the true branch (field name kept for compatibility). */
    protected SameDiff loopBodyExecution;
    /** Sub-graph that evaluates the predicate. */
    protected SameDiff predicateExecution;
    /** Sub-graph for the false branch. */
    protected SameDiff falseBodyExecution;
    protected SameDiffConditional predicate;
    protected SameDiffFunctionDefinition trueBody;
    protected SameDiffFunctionDefinition falseBody;
    protected String blockName;
    /** Name under which the true body is registered in the parent SameDiff. */
    protected String trueBodyName;
    /** Name under which the false body is registered in the parent SameDiff. */
    protected String falseBodyName;
    protected SDVariable[] inputVars;
    /** {@code null} until a branch has been recorded via {@link #exectedTrueOrFalse(boolean)}. */
    protected Boolean trueBodyExecuted = null;
    /** Result of evaluating the predicate in {@code predicateExecution}. */
    protected SDVariable targetBoolean;
    /** Placeholder output variable registered in the owning SameDiff. */
    protected SDVariable dummyResult;
    protected SDVariable[] outputVars;

    /**
     * Copy constructor.
     *
     * <p>All branch, predicate and input state is shared by reference with {@code ifStatement}.
     * A fresh zero-initialised placeholder output is created in the shared {@link SameDiff}, so
     * the copy gets its own output variable instead of reusing the original's.
     *
     * @param ifStatement the op to copy; its {@code sameDiff} must be non-null
     */
    public If(If ifStatement) {
        this.sameDiff = ifStatement.sameDiff;
        this.loopBodyExecution = ifStatement.loopBodyExecution;
        this.predicateExecution = ifStatement.predicateExecution;
        this.falseBodyExecution = ifStatement.falseBodyExecution;
        this.predicate = ifStatement.predicate;
        this.trueBody = ifStatement.trueBody;
        this.falseBody = ifStatement.falseBody;
        this.blockName = ifStatement.blockName;
        this.trueBodyName = ifStatement.trueBodyName;
        this.falseBodyName = ifStatement.falseBodyName;
        this.inputVars = ifStatement.inputVars;
        this.trueBodyExecuted = ifStatement.trueBodyExecuted;
        this.targetBoolean = ifStatement.targetBoolean;
        this.outputVars = ifStatement.outputVars;
        this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(), new ZeroInitScheme(), DataType.FLOAT, 1L);
        if (this.sameDiff.getShapeForVarName(this.dummyResult.getVarName()) == null) {
            this.sameDiff.putShapeForVarName(this.dummyResult.getVarName(), new long[]{1L, 1L});
        }
    }

    /**
     * Creates the op and registers it, its inputs and its sub-graphs in {@code parent}.
     *
     * @param blockName     name under which {@code conditionBody} is registered in {@code parent}
     * @param parent        owning SameDiff
     * @param inputVars     inputs passed to the condition, true and false bodies
     * @param conditionBody definition of the condition sub-graph
     * @param predicate     evaluates {@code conditionBody} into {@code targetBoolean}
     * @param trueBody      definition of the true branch
     * @param falseBody     definition of the false branch
     */
    public If(String blockName, SameDiff parent, SDVariable[] inputVars, SameDiffFunctionDefinition conditionBody, SameDiffConditional predicate, SameDiffFunctionDefinition trueBody, SameDiffFunctionDefinition falseBody) {
        this.sameDiff = parent;
        parent.putFunctionForId(this.getOwnName(), this);
        this.inputVars = inputVars;
        this.predicate = predicate;
        parent.addArgsFor(inputVars, (DifferentialFunction)this);
        this.trueBody = trueBody;
        this.falseBody = falseBody;
        this.blockName = blockName;
        this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(), new ZeroInitScheme('f'), DataType.FLOAT, 1L);
        parent.addOutgoingFor(new SDVariable[]{this.dummyResult}, (DifferentialFunction)this);
        // The predicate is evaluated into its own, separate SameDiff instance.
        SameDiff predicateGraph = SameDiff.create();
        this.targetBoolean = predicate.eval(predicateGraph, conditionBody, inputVars);
        this.predicateExecution = predicateGraph;
        // Random suffixes keep the sub-function names distinct within parent.
        this.trueBodyName = "true-body-" + UUID.randomUUID().toString();
        this.falseBodyName = "false-body-" + UUID.randomUUID().toString();
        this.loopBodyExecution = parent.defineFunction(this.trueBodyName, trueBody, inputVars);
        this.falseBodyExecution = parent.defineFunction(this.falseBodyName, falseBody, inputVars);
        parent.defineFunction(blockName, conditionBody, inputVars);
        parent.putSubFunction("predicate-eval-body-" + UUID.randomUUID().toString(), predicateGraph);
        // NOTE(review): looks redundant with the defineFunction result above; kept to preserve behaviour.
        this.loopBodyExecution = parent.getFunction(this.trueBodyName);
    }

    /**
     * Records which branch was executed. The method name's typo is kept for API compatibility.
     *
     * @param trueBodyExecuted {@code true} if the true body ran, {@code false} if the false body ran
     */
    public void exectedTrueOrFalse(boolean trueBodyExecuted) {
        this.trueBodyExecuted = Boolean.valueOf(trueBodyExecuted);
    }

    /** Returns the single placeholder output; {@code baseName} is ignored. */
    @Override
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[]{this.dummyResult};
    }

    /** Gradient is delegated to an {@link IfDerivative} built from this op; {@code f1} is unused. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        return new ArrayList<SDVariable>(Arrays.asList(new IfDerivative(this).outputVariables()));
    }

    @Override
    public String toString() {
        return this.opName();
    }

    @Override
    public String opName() {
        return "if";
    }

    @Override
    public long opHash() {
        return HashUtil.getLongHash(this.opName());
    }

    @Override
    public boolean isInplaceCall() {
        return false;
    }

    // This op carries no explicit CustomOp arguments: the accessors below return empty
    // arrays, zero counts or null, and the mutators are no-ops.

    @Override
    public INDArray[] outputArguments() {
        return new INDArray[0];
    }

    @Override
    public INDArray[] inputArguments() {
        return new INDArray[0];
    }

    @Override
    public long[] iArgs() {
        return new long[0];
    }

    @Override
    public double[] tArgs() {
        return new double[0];
    }

    @Override
    public boolean[] bArgs() {
        return new boolean[0];
    }

    @Override
    public void addIArgument(int ... arg) {
    }

    @Override
    public void addIArgument(long ... arg) {
    }

    @Override
    public void addBArgument(boolean ... arg) {
    }

    @Override
    public void removeIArgument(Integer arg) {
    }

    @Override
    public Boolean getBArgument(int index) {
        return null;
    }

    @Override
    public Long getIArgument(int index) {
        return null;
    }

    @Override
    public int numIArguments() {
        return 0;
    }

    @Override
    public void addTArgument(double ... arg) {
    }

    @Override
    public void removeTArgument(Double arg) {
    }

    @Override
    public Double getTArgument(int index) {
        return null;
    }

    @Override
    public int numTArguments() {
        return 0;
    }

    @Override
    public int numBArguments() {
        return 0;
    }

    @Override
    public void addInputArgument(INDArray ... arg) {
    }

    @Override
    public void removeInputArgument(INDArray arg) {
    }

    @Override
    public INDArray getInputArgument(int index) {
        return null;
    }

    @Override
    public int numInputArguments() {
        return 0;
    }

    @Override
    public void addOutputArgument(INDArray ... arg) {
    }

    @Override
    public void removeOutputArgument(INDArray arg) {
    }

    @Override
    public INDArray getOutputArgument(int index) {
        return null;
    }

    @Override
    public int numOutputArguments() {
        return 0;
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CONDITIONAL;
    }

    /**
     * Imports the true, false and condition scopes of a TensorFlow conditional.
     *
     * <p>Each scope is imported as its own {@link SameDiff} and registered as a sub-function of
     * {@code initWith}, using the scope names reported by {@link IfImportState}.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        // Nodes whose name contains "/cond/" are skipped. Presumably they are handled as part
        // of an enclosing conditional's scopes (TODO confirm).
        if (nodeDef.getName().contains("/cond/")) {
            return;
        }
        TFGraphMapper mapper = TFGraphMapper.getInstance();
        IfImportState ifNodes = mapper.nodesForIf(nodeDef, graph);
        SameDiff trueScope = importScope(mapper, ifNodes.getTrueNodes());
        SameDiff falseScope = importScope(mapper, ifNodes.getFalseNodes());
        SameDiff condScope = importScope(mapper, ifNodes.getCondNodes());
        initWith.putSubFunction(ifNodes.getTrueBodyScopeName(), trueScope);
        initWith.putSubFunction(ifNodes.getFalseBodyScopeName(), falseScope);
        initWith.putSubFunction(ifNodes.getConditionBodyScopeName(), condScope);
        this.loopBodyExecution = trueScope;
        this.falseBodyExecution = falseScope;
        this.predicateExecution = condScope;
    }

    /** Builds a standalone {@link GraphDef} from {@code nodes} and imports it as a new SameDiff. */
    private static SameDiff importScope(TFGraphMapper mapper, Iterable<NodeDef> nodes) {
        GraphDef.Builder builder = GraphDef.newBuilder();
        for (NodeDef node : nodes) {
            builder.addNode(node);
        }
        return mapper.importGraph(builder.build());
    }

    /** ONNX import is not supported; this is a no-op. */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    /**
     * Declares a single scalar BOOL output.
     *
     * <p>NOTE(review): {@code dummyResult} is created as FLOAT with shape [1]. Confirm which
     * declaration is intended.
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return Arrays.asList(LongShapeDescriptor.fromShape(new long[0], DataType.BOOL));
    }

    @Override
    public CustomOpDescriptor getDescriptor() {
        return null;
    }

    @Override
    public void assertValidForExecution() {
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("This operation has no TF counterpart");
    }

    /** Returns a new builder for the full {@code If} constructor. */
    public static IfBuilder builder() {
        return new IfBuilder();
    }

    public If() {
    }

    /** Returns the true-branch sub-graph. */
    public SameDiff getLoopBodyExecution() {
        return this.loopBodyExecution;
    }

    /** Returns the predicate sub-graph. */
    public SameDiff getPredicateExecution() {
        return this.predicateExecution;
    }

    /** Returns the false-branch sub-graph. */
    public SameDiff getFalseBodyExecution() {
        return this.falseBodyExecution;
    }

    public SameDiffConditional getPredicate() {
        return this.predicate;
    }

    public SameDiffFunctionDefinition getTrueBody() {
        return this.trueBody;
    }

    public SameDiffFunctionDefinition getFalseBody() {
        return this.falseBody;
    }

    public String getBlockName() {
        return this.blockName;
    }

    public String getTrueBodyName() {
        return this.trueBodyName;
    }

    public String getFalseBodyName() {
        return this.falseBodyName;
    }

    public SDVariable[] getInputVars() {
        return this.inputVars;
    }

    /** Returns which branch was recorded as executed, or {@code null} if none has been recorded yet. */
    public Boolean getTrueBodyExecuted() {
        return this.trueBodyExecuted;
    }

    public SDVariable getTargetBoolean() {
        return this.targetBoolean;
    }

    public SDVariable[] getOutputVars() {
        return this.outputVars;
    }

    public void setOutputVars(SDVariable[] outputVars) {
        this.outputVars = outputVars;
    }

    /** Fluent builder that forwards to the full {@code If} constructor. */
    public static class IfBuilder {
        private String blockName;
        private SameDiff parent;
        private SDVariable[] inputVars;
        private SameDiffFunctionDefinition conditionBody;
        private SameDiffConditional predicate;
        private SameDiffFunctionDefinition trueBody;
        private SameDiffFunctionDefinition falseBody;

        IfBuilder() {
        }

        public IfBuilder blockName(String blockName) {
            this.blockName = blockName;
            return this;
        }

        public IfBuilder parent(SameDiff parent) {
            this.parent = parent;
            return this;
        }

        public IfBuilder inputVars(SDVariable[] inputVars) {
            this.inputVars = inputVars;
            return this;
        }

        public IfBuilder conditionBody(SameDiffFunctionDefinition conditionBody) {
            this.conditionBody = conditionBody;
            return this;
        }

        public IfBuilder predicate(SameDiffConditional predicate) {
            this.predicate = predicate;
            return this;
        }

        public IfBuilder trueBody(SameDiffFunctionDefinition trueBody) {
            this.trueBody = trueBody;
            return this;
        }

        public IfBuilder falseBody(SameDiffFunctionDefinition falseBody) {
            this.falseBody = falseBody;
            return this;
        }

        /** Creates the op; this registers it in {@code parent} as a side effect. */
        public If build() {
            return new If(this.blockName, this.parent, this.inputVars, this.conditionBody, this.predicate, this.trueBody, this.falseBody);
        }

        @Override
        public String toString() {
            return "If.IfBuilder(blockName=" + this.blockName + ", parent=" + this.parent + ", inputVars=" + Arrays.deepToString(this.inputVars) + ", conditionBody=" + this.conditionBody + ", predicate=" + this.predicate + ", trueBody=" + this.trueBody + ", falseBody=" + this.falseBody + ")";
        }
    }
}

