/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.custom;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;

public class Invoke
extends DynamicCustomOp {
    private String functionName;
    private String[] inputVarNames;
    private String[] outputVarNames;
    private String[] subGraphInputVarNames;
    private String[] subGraphOutputVarNames;

    public Invoke() {
    }

    public Invoke(SameDiff sameDiff, InvokeParams invokeParams) {
        super(sameDiff, invokeParams.inputs);
        this.sameDiff = sameDiff;
        this.outputVarNames = invokeParams.outputVarNames;
        this.functionName = invokeParams.functionName;
        this.inputVarNames = invokeParams.inputVarNames;
        this.subGraphInputVarNames = invokeParams.subGraphInputVarNames;
        this.subGraphOutputVarNames = invokeParams.subGraphOutputVarNames;
    }

    public static ExecutionResult doInvoke(DifferentialFunction op, Map<String, INDArray> placeHolders, Map<String, SDValue> valuePlaceHolders) {
        String[] subGraphOutputNames;
        String[] outputVarNameMappings;
        Invoke invoke = (Invoke)op;
        String funcName = invoke.getFunctionName();
        SameDiff instance = op.getSameDiff().getFunction(funcName);
        instance.setEnableCache(false);
        SDVariable[] args = op.args();
        String[] inputVarNameMappings = invoke.getInputVarNames();
        String[] subGraphInputNames = invoke.subGraphInputVarNames;
        if (subGraphInputNames == null) {
            subGraphInputNames = inputVarNameMappings;
        }
        SDVariable[] outputs = op.outputVariables();
        if (inputVarNameMappings == null) {
            inputVarNameMappings = new String[args.length];
            for (int i = 0; i < inputVarNameMappings.length; ++i) {
                inputVarNameMappings[i] = args[i].name();
            }
        }
        if ((outputVarNameMappings = invoke.getOutputVarNames()) == null) {
            outputVarNameMappings = new String[outputs.length];
            for (int i = 0; i < outputs.length; ++i) {
                outputVarNameMappings[i] = outputs[i].name();
            }
        }
        if ((subGraphOutputNames = invoke.subGraphOutputVarNames) == null) {
            subGraphOutputNames = outputVarNameMappings;
        }
        List<String> relevantOutputNames = Arrays.asList(subGraphOutputNames);
        if (valuePlaceHolders.isEmpty()) {
            INDArray[] retOutput = new INDArray[subGraphOutputNames.length];
            LinkedHashMap<String, INDArray> inputMap = new LinkedHashMap<String, INDArray>();
            for (int i = 0; i < inputVarNameMappings.length; ++i) {
                inputMap.put(subGraphInputNames[i], placeHolders.get(op.argNames()[i]));
            }
            Map<String, INDArray> output = instance.output(inputMap, relevantOutputNames);
            int numAdded = 0;
            for (Map.Entry<String, INDArray> result : output.entrySet()) {
                if (!relevantOutputNames.contains(result.getKey())) continue;
                retOutput[numAdded] = output.get(result.getKey());
                ++numAdded;
            }
            return ExecutionResult.builder().outputs(ExecutionResult.pack(output)).build();
        }
        LinkedHashMap<String, SDValue> valueInputs = new LinkedHashMap<String, SDValue>();
        for (int i = 0; i < inputVarNameMappings.length; ++i) {
            valueInputs.put(subGraphInputNames[i], valuePlaceHolders.get(op.argNames()[i]));
        }
        Map<String, SDValue> valueOutputs = instance.outputValues(valueInputs, relevantOutputNames);
        LinkedHashMap<String, SDValue> result = new LinkedHashMap<String, SDValue>();
        for (int i = 0; i < outputVarNameMappings.length; ++i) {
            result.put(outputs[i].name(), valueOutputs.get(subGraphOutputNames[i]));
        }
        return ExecutionResult.builder().valueOutputs(result).build();
    }

    @Override
    public SDVariable[] outputVariables() {
        if (this.outputVariables == null) {
            int i;
            SameDiff func = this.sameDiff.getFunction(this.functionName);
            if (func == null) {
                throw new IllegalArgumentException("Unable to determine output data types for variables. No function of " + this.functionName + " found!");
            }
            if (this.subGraphOutputVarNames == null) {
                throw new IllegalStateException("Invalid InvokeConfiguration found. Please specify sub graph output names.");
            }
            SDVariable[] outputs = new SDVariable[this.subGraphOutputVarNames.length];
            block4: for (i = 0; i < this.subGraphOutputVarNames.length; ++i) {
                String subGraphVarName = this.subGraphOutputVarNames[i];
                SDVariable variable = func.getVariable(subGraphVarName);
                if (variable == null) {
                    throw new IllegalStateException("No variable found in sub graph named " + subGraphVarName);
                }
                switch (variable.getVariableType()) {
                    case VARIABLE: 
                    case ARRAY: 
                    case PLACEHOLDER: 
                    case SEQUENCE: {
                        SDVariable clone2;
                        if (variable.getShape() != null) {
                            clone2 = this.sameDiff.var(subGraphVarName + "_" + this.functionName, variable.dataType(), variable.getShape());
                            clone2.setVariableType(VariableType.ARRAY);
                            outputs[i] = clone2;
                            continue block4;
                        }
                        clone2 = this.sameDiff.var(subGraphVarName + "_" + this.functionName, variable.dataType(), new int[0]);
                        clone2.setVariableType(VariableType.ARRAY);
                        outputs[i] = clone2;
                        continue block4;
                    }
                    case CONSTANT: {
                        SDVariable clone2 = this.sameDiff.var(subGraphVarName + "_" + this.functionName, variable.dataType(), new int[0]);
                        clone2.setVariableType(VariableType.ARRAY);
                        outputs[i] = clone2;
                    }
                }
            }
            this.outputVariables = outputs;
            if (this.outputVarNames != null && this.outputVarNames.length == outputs.length) {
                for (i = 0; i < outputs.length; ++i) {
                    if (outputs[i].name().equals(this.outputVarNames[i])) continue;
                    this.sameDiff.updateVariableNameAndReference(outputs[i], this.outputVarNames[i], true);
                }
            } else if (this.outputVariables == null) {
                throw new IllegalArgumentException("Invalid configuration for output variable names. Must be equal to the number of outputs.");
            }
            this.addOutputsToOp();
            return outputs;
        }
        return this.outputVariables;
    }

    @Override
    public int getNumOutputs() {
        if (this.subGraphOutputVarNames != null) {
            return this.subGraphOutputVarNames.length;
        }
        if (this.outputVarNames != null) {
            return this.outputVarNames.length;
        }
        return 1;
    }

    @Override
    public String opName() {
        return "invoke";
    }

    @Override
    public void configureFromArguments() {
        super.configureFromArguments();
    }

    @Override
    public void configureWithSameDiff(SameDiff sameDiff) {
        super.configureWithSameDiff(sameDiff);
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        ArrayList<DataType> ret = new ArrayList<DataType>();
        for (int i = 0; i < this.getNumOutputs(); ++i) {
            ret.add(DataType.FLOAT);
        }
        return ret;
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return Collections.emptyList();
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
        ArrayList<LongShapeDescriptor> ret = new ArrayList<LongShapeDescriptor>();
        for (int i = 0; i < this.getNumOutputs(); ++i) {
            ret.add(LongShapeDescriptor.fromShape(new int[]{1}, DataType.DOUBLE));
        }
        return ret;
    }

    public String getFunctionName() {
        return this.functionName;
    }

    public String[] getInputVarNames() {
        return this.inputVarNames;
    }

    public String[] getOutputVarNames() {
        return this.outputVarNames;
    }

    public String[] getSubGraphInputVarNames() {
        return this.subGraphInputVarNames;
    }

    public String[] getSubGraphOutputVarNames() {
        return this.subGraphOutputVarNames;
    }

    public static class InvokeParams {
        private String functionName;
        private SDVariable[] inputs;
        private String[] inputVarNames;
        private String[] outputVarNames;
        private String[] subGraphInputVarNames;
        private String[] subGraphOutputVarNames;

        InvokeParams(String functionName, SDVariable[] inputs, String[] inputVarNames, String[] outputVarNames, String[] subGraphInputVarNames, String[] subGraphOutputVarNames) {
            this.functionName = functionName;
            this.inputs = inputs;
            this.inputVarNames = inputVarNames;
            this.outputVarNames = outputVarNames;
            this.subGraphInputVarNames = subGraphInputVarNames;
            this.subGraphOutputVarNames = subGraphOutputVarNames;
        }

        public static InvokeParamsBuilder builder() {
            return new InvokeParamsBuilder();
        }

        public String getFunctionName() {
            return this.functionName;
        }

        public SDVariable[] getInputs() {
            return this.inputs;
        }

        public String[] getInputVarNames() {
            return this.inputVarNames;
        }

        public String[] getOutputVarNames() {
            return this.outputVarNames;
        }

        public String[] getSubGraphInputVarNames() {
            return this.subGraphInputVarNames;
        }

        public String[] getSubGraphOutputVarNames() {
            return this.subGraphOutputVarNames;
        }

        public void setFunctionName(String functionName) {
            this.functionName = functionName;
        }

        public void setInputs(SDVariable[] inputs) {
            this.inputs = inputs;
        }

        public void setInputVarNames(String[] inputVarNames) {
            this.inputVarNames = inputVarNames;
        }

        public void setOutputVarNames(String[] outputVarNames) {
            this.outputVarNames = outputVarNames;
        }

        public void setSubGraphInputVarNames(String[] subGraphInputVarNames) {
            this.subGraphInputVarNames = subGraphInputVarNames;
        }

        public void setSubGraphOutputVarNames(String[] subGraphOutputVarNames) {
            this.subGraphOutputVarNames = subGraphOutputVarNames;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof InvokeParams)) {
                return false;
            }
            InvokeParams other = (InvokeParams)o;
            if (!other.canEqual(this)) {
                return false;
            }
            String this$functionName = this.getFunctionName();
            String other$functionName = other.getFunctionName();
            if (this$functionName == null ? other$functionName != null : !this$functionName.equals(other$functionName)) {
                return false;
            }
            if (!Arrays.deepEquals(this.getInputs(), other.getInputs())) {
                return false;
            }
            if (!Arrays.deepEquals(this.getInputVarNames(), other.getInputVarNames())) {
                return false;
            }
            if (!Arrays.deepEquals(this.getOutputVarNames(), other.getOutputVarNames())) {
                return false;
            }
            if (!Arrays.deepEquals(this.getSubGraphInputVarNames(), other.getSubGraphInputVarNames())) {
                return false;
            }
            return Arrays.deepEquals(this.getSubGraphOutputVarNames(), other.getSubGraphOutputVarNames());
        }

        protected boolean canEqual(Object other) {
            return other instanceof InvokeParams;
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            String $functionName = this.getFunctionName();
            result = result * 59 + ($functionName == null ? 43 : $functionName.hashCode());
            result = result * 59 + Arrays.deepHashCode(this.getInputs());
            result = result * 59 + Arrays.deepHashCode(this.getInputVarNames());
            result = result * 59 + Arrays.deepHashCode(this.getOutputVarNames());
            result = result * 59 + Arrays.deepHashCode(this.getSubGraphInputVarNames());
            result = result * 59 + Arrays.deepHashCode(this.getSubGraphOutputVarNames());
            return result;
        }

        public String toString() {
            return "Invoke.InvokeParams(functionName=" + this.getFunctionName() + ", inputs=" + Arrays.deepToString(this.getInputs()) + ", inputVarNames=" + Arrays.deepToString(this.getInputVarNames()) + ", outputVarNames=" + Arrays.deepToString(this.getOutputVarNames()) + ", subGraphInputVarNames=" + Arrays.deepToString(this.getSubGraphInputVarNames()) + ", subGraphOutputVarNames=" + Arrays.deepToString(this.getSubGraphOutputVarNames()) + ")";
        }

        public static class InvokeParamsBuilder {
            private String functionName;
            private SDVariable[] inputs;
            private String[] inputVarNames;
            private String[] outputVarNames;
            private String[] subGraphInputVarNames;
            private String[] subGraphOutputVarNames;

            InvokeParamsBuilder() {
            }

            public InvokeParamsBuilder functionName(String functionName) {
                this.functionName = functionName;
                return this;
            }

            public InvokeParamsBuilder inputs(SDVariable[] inputs) {
                this.inputs = inputs;
                return this;
            }

            public InvokeParamsBuilder inputVarNames(String[] inputVarNames) {
                this.inputVarNames = inputVarNames;
                return this;
            }

            public InvokeParamsBuilder outputVarNames(String[] outputVarNames) {
                this.outputVarNames = outputVarNames;
                return this;
            }

            public InvokeParamsBuilder subGraphInputVarNames(String[] subGraphInputVarNames) {
                this.subGraphInputVarNames = subGraphInputVarNames;
                return this;
            }

            public InvokeParamsBuilder subGraphOutputVarNames(String[] subGraphOutputVarNames) {
                this.subGraphOutputVarNames = subGraphOutputVarNames;
                return this;
            }

            public InvokeParams build() {
                return new InvokeParams(this.functionName, this.inputs, this.inputVarNames, this.outputVarNames, this.subGraphInputVarNames, this.subGraphOutputVarNames);
            }

            public String toString() {
                return "Invoke.InvokeParams.InvokeParamsBuilder(functionName=" + this.functionName + ", inputs=" + Arrays.deepToString(this.inputs) + ", inputVarNames=" + Arrays.deepToString(this.inputVarNames) + ", outputVarNames=" + Arrays.deepToString(this.outputVarNames) + ", subGraphInputVarNames=" + Arrays.deepToString(this.subGraphInputVarNames) + ", subGraphOutputVarNames=" + Arrays.deepToString(this.subGraphOutputVarNames) + ")";
            }
        }
    }
}

