/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.NativeFunction;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFunction;
import org.tensorflow.internal.c_api.TF_Function;
import org.tensorflow.internal.c_api.TF_Operation;
import org.tensorflow.internal.c_api.TF_Output;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.internal.types.registry.TensorTypeRegistry;
import org.tensorflow.op.Ops;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.PartitionedCall;
import org.tensorflow.proto.framework.AttrValue;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.FunctionDef;
import org.tensorflow.proto.framework.OpDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.framework.TensorShapeProto;
import org.tensorflow.types.family.TType;

/**
 * A TensorFlow graph function paired with a {@link Signature} describing its named inputs and
 * outputs.
 *
 * <p>Each instance holds a {@link NativeFunction} and the set of native {@link TF_Function}
 * handles it depends on. Both are attached to a {@link PointerScope} owned by this object, and
 * {@link #close()} closes that scope.
 *
 * <p>NOTE(review): this file was produced by the CFR decompiler. {@code buildFromGraph} failed to
 * decompile and its stub always throws {@link IllegalStateException}. Consequently every public
 * {@code create(...)} factory in this class throws at runtime in its current form. Only
 * {@link #fromNativeHandle} (package-private) can produce an instance.
 */
public final class ConcreteFunction implements AutoCloseable, TensorFunction {
    /** Named inputs/outputs of this function. */
    private final Signature signature;
    /** Wrapper around the native function handle and its FunctionDef. */
    private final NativeFunction nativeFunction;
    /** Scope that the native handle and all dependency handles are attached to; closed by {@link #close()}. */
    private final PointerScope scope;
    /** Native functions this function depends on (unmodifiable). */
    private final Set<TF_Function> dependencies;
    /** Java tensor classes of the outputs, in signature order; passed to PartitionedCall. */
    private final List<Class<? extends TType>> outputTypes;

    /**
     * Builds a function by running {@code functionBuilder} against a fresh graph.
     *
     * <p>NOTE(review): delegates to {@code buildFromGraph}, whose decompiled stub always throws
     * {@link IllegalStateException}.
     *
     * @param functionBuilder receives an {@link Ops} bound to a new graph and returns the signature
     * @return the built function
     */
    public static ConcreteFunction create(Function<Ops, Signature> functionBuilder) {
        try (Graph graph = new Graph()) {
            Ops tf = Ops.create(graph);
            Signature signature = functionBuilder.apply(tf);
            return ConcreteFunction.buildFromGraph(graph, signature);
        }
    }

    /**
     * Builds a function from an existing graph and signature.
     *
     * <p>NOTE(review): delegates to {@code buildFromGraph}, whose decompiled stub always throws
     * {@link IllegalStateException}.
     */
    public static ConcreteFunction create(Signature signature, Graph graph) {
        return ConcreteFunction.buildFromGraph(graph, signature);
    }

    /**
     * Builds a function from the graph of {@code session} and the given signature.
     *
     * <p>NOTE(review): delegates to {@code buildFromGraph}, whose decompiled stub always throws
     * {@link IllegalStateException}.
     */
    public static ConcreteFunction create(Signature signature, Session session) {
        return ConcreteFunction.buildFromGraph(session.graph(), signature);
    }

    /** Returns the signature this function was created with. */
    @Override
    public Signature signature() {
        return this.signature;
    }

    /** Returns the name under which the native function is defined. */
    public String getDefinedName() {
        return this.nativeFunction.getName();
    }

    /** Returns the {@link FunctionDef} of the native function. */
    public FunctionDef getFunctionDef() {
        return this.nativeFunction.getFunctionDef();
    }

    /** Returns whether the native function reports itself as stateful. */
    public boolean isStateful() {
        return this.nativeFunction.isStateful();
    }

    /** Returns the (unmodifiable) set of native functions this function depends on. */
    Set<TF_Function> getDependencies() {
        return this.dependencies;
    }

    /** Closes the pointer scope holding the native function and its dependency handles. */
    @Override
    public void close() {
        this.scope.close();
    }

    @Override
    public String toString() {
        return this.signature.toString();
    }

    /**
     * Adds a call to this function to the graph/execution environment of {@code scope}.
     *
     * @param scope scope in which the PartitionedCall op is created
     * @param arguments one operand per signature input, keyed by input name
     * @return the call's outputs keyed by signature output name (empty if there are none)
     * @throws IllegalArgumentException if an input is missing or mapped to {@code null}
     * @throws IllegalStateException if fewer outputs are returned than the signature declares
     */
    public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> arguments) {
        List<Operand<?>> inputList = new ArrayList<>(this.signature.inputNames().size());
        // Inputs are passed positionally, in the signature's input-name order.
        for (String inputName : this.signature.inputNames()) {
            if (!arguments.containsKey(inputName)) {
                throw new IllegalArgumentException("Function " + this.signature.methodName() + " has parameter \"" + inputName + "\", but no argument was passed for it.");
            }
            Operand<?> input = arguments.get(inputName);
            if (input == null) {
                throw new IllegalArgumentException("Can't pass null as an argument to a function.  Argument \"" + inputName + "\" was null.");
            }
            inputList.add(input);
        }
        List<Output<?>> outputList = PartitionedCall.create(scope, inputList, this.outputTypes, this).output();
        if (this.signature.outputNames().size() == 0) {
            return Collections.emptyMap();
        }
        if (this.signature.outputNames().size() == 1) {
            return Collections.singletonMap(this.signature.outputNames().iterator().next(), outputList.get(0));
        }
        if (outputList.size() < this.signature.outputNames().size()) {
            throw new IllegalStateException("Somehow, not all required outputs were returned from the function(expected: " + this.signature.outputNames().size() + ", returned: " + outputList.size() + ")");
        }
        // Outputs are matched to names positionally, in the signature's output-name order.
        Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(this.signature.outputNames().size());
        Iterator<String> outputNames = this.signature.outputNames().iterator();
        int i = 0;
        while (outputNames.hasNext()) {
            String outputName = outputNames.next();
            Operand<?> output = outputList.get(i);
            namedOutputs.put(outputName, output);
            ++i;
        }
        return Collections.unmodifiableMap(namedOutputs);
    }

    /**
     * Calls a single-input, single-output function.
     *
     * @param scope scope in which the call is created
     * @param argument the only input
     * @return the only output
     * @throws IllegalArgumentException if the signature does not have exactly one input and exactly
     *     one output
     */
    public Operand<?> call(Scope scope, Operand<?> argument) {
        SignatureDef signatureDef = this.signature.asSignatureDef();
        // NOTE(review): this message is also produced when the function has zero inputs.
        if (signatureDef.getInputsCount() != 1) {
            throw new IllegalArgumentException(String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
        }
        String inputName = signatureDef.getInputsMap().keySet().iterator().next();
        if (signatureDef.getOutputsCount() != 1) {
            throw new IllegalArgumentException(String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
        }
        String outputName = signatureDef.getOutputsMap().keySet().iterator().next();
        return this.call(scope, Collections.singletonMap(inputName, argument)).get(outputName);
    }

    /**
     * Calls this function with tensor arguments through an {@link Ops} created by
     * {@code Ops.create()} and returns the outputs as tensors.
     *
     * <p>Each argument is wrapped with {@code tf.constantOf((TType) argument)}, so each argument is
     * expected to be a {@link TType}.
     *
     * @param arguments input tensors keyed by input name
     * @return output tensors keyed by output name, in the order returned by the call
     */
    @Override
    public Map<String, Tensor> call(Map<String, Tensor> arguments) {
        Ops tf = Ops.create();
        Map<String, Operand<?>> inputs = new LinkedHashMap<>(arguments.size());
        for (Map.Entry<String, Tensor> entry : arguments.entrySet()) {
            inputs.put(entry.getKey(), tf.constantOf((TType) entry.getValue()));
        }
        Map<String, Operand<?>> outputs = tf.call(this, inputs);
        Map<String, Tensor> tensorOutputs = new LinkedHashMap<>(outputs.size());
        for (Map.Entry<String, Operand<?>> entry : outputs.entrySet()) {
            tensorOutputs.put(entry.getKey(), (Tensor) entry.getValue().asTensor());
        }
        return tensorOutputs;
    }

    /** Calls this function via {@code tf.call(this, arguments)}. */
    public Map<String, Operand<?>> call(Ops tf, Map<String, Operand<?>> arguments) {
        return tf.call(this, arguments);
    }

    /** Calls this single-input, single-output function via {@code tf.call(this, argument)}. */
    public Operand<?> call(Ops tf, Operand<?> argument) {
        return tf.call(this, argument);
    }

    /**
     * Returns the native function handle.
     *
     * @throws IllegalStateException if the handle is null (the function has been closed)
     */
    TF_Function nativeHandle() {
        if (this.nativeFunction.getNativeHandle().isNull()) {
            throw new IllegalStateException("Function has been closed");
        }
        return this.nativeFunction.getNativeHandle();
    }

    /**
     * Creates a function whose dependencies are resolved from {@code availableFunctions} via
     * {@link NativeFunction#getAllDependencies}.
     */
    ConcreteFunction(Signature signature, NativeFunction nativeFunction, Collection<NativeFunction> availableFunctions) {
        this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions));
    }

    /**
     * Wraps a native function, deriving a signature from its FunctionDef.
     *
     * <p>The method name comes from the FunctionDef signature name and the key from
     * {@code nativeFunction.getName()}. Every input and output argument becomes a signature entry
     * with its declared dtype and an unknown-rank shape.
     */
    static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction, Collection<NativeFunction> availableFunctions) {
        OpDef functionSignature = nativeFunction.getFunctionDef().getSignature();
        Signature.Builder builder = Signature.builder().methodName(functionSignature.getName()).key(nativeFunction.getName());
        for (OpDef.ArgDef input : functionSignature.getInputArgList()) {
            TensorInfo info = TensorInfo.newBuilder().setDtype(input.getType()).setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()).setName(input.getName()).build();
            builder.input(input.getName(), info);
        }
        for (OpDef.ArgDef outputDef : functionSignature.getOutputArgList()) {
            TensorInfo info = TensorInfo.newBuilder().setDtype(outputDef.getType()).setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()).setName(outputDef.getName()).build();
            builder.output(outputDef.getName(), info);
        }
        return new ConcreteFunction(builder.build(), nativeFunction, availableFunctions);
    }

    /**
     * Validates {@code signature} against the native function and takes ownership of the native
     * handles.
     *
     * @throws IllegalArgumentException if input/output counts differ from the native function, or
     *     if data types do not match in order ({@code DT_INVALID} matches anything)
     */
    private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set<TF_Function> dependencies) {
        this.signature = signature;
        this.nativeFunction = nativeFunction;
        this.dependencies = Collections.unmodifiableSet(dependencies);
        OpDef nativeSignature = nativeFunction.getFunctionDef().getSignature();
        if (signature.getInputs().size() != nativeSignature.getInputArgCount()) {
            throw new IllegalArgumentException("Signature must have the same number of inputs as the native function.  Expected " + nativeSignature.getInputArgCount() + ", got " + this.signature.getInputs().size());
        }
        if (signature.getOutputs().size() != nativeSignature.getOutputArgCount()) {
            throw new IllegalArgumentException("New signature must have the same number of outputs as the native function.  Expected " + nativeSignature.getOutputArgCount() + ", got " + this.signature.getOutputs().size());
        }
        List<DataType> inputs = signature.getInputs().values().stream().map(x -> x.dataType).collect(Collectors.toList());
        List<DataType> nativeInputs = nativeSignature.getInputArgList().stream().map(OpDef.ArgDef::getType).collect(Collectors.toList());
        if (!ConcreteFunction.dataTypesMatch(inputs, nativeInputs)) {
            throw new IllegalArgumentException("Data types of the signature's inputs must match the native function's (in order).  Expected " + nativeInputs + ", got " + inputs);
        }
        List<DataType> outputs = signature.getOutputs().values().stream().map(x -> x.dataType).collect(Collectors.toList());
        List<DataType> nativeOutputs = nativeSignature.getOutputArgList().stream().map(OpDef.ArgDef::getType).collect(Collectors.toList());
        if (!ConcreteFunction.dataTypesMatch(outputs, nativeOutputs)) {
            throw new IllegalArgumentException("Data types of the signature's outputs must match the native function's (in order).  Expected " + nativeOutputs + ", got " + outputs);
        }
        this.outputTypes = outputs.stream().map(x -> TensorTypeRegistry.find(x).type()).collect(Collectors.toList());
        // extend() is called before the try-with-resources closes the scope, so the attached handles
        // outlive this constructor; they are released when close() closes the scope.
        try (PointerScope scope = new PointerScope()) {
            this.scope = scope;
            scope.extend();
            scope.attach(this.nativeFunction.getNativeHandle());
            this.dependencies.forEach(scope::attach);
        }
    }

    /**
     * Sets the {@code _XlaMustCompile} and {@code _noinline} attributes of the native function to
     * {@code true}.
     *
     * @throws RuntimeException (via {@code throwExceptionIfNotOK}) if the native call reports an
     *     error
     */
    private void makeJit() {
        try (PointerScope scope = new PointerScope()) {
            byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray();
            BytePointer trueValue = new BytePointer(bytes);
            TF_Status status1 = TF_Status.newStatus();
            tensorflow.TF_FunctionSetAttrValueProto(this.nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1);
            status1.throwExceptionIfNotOK();
            TF_Status status2 = TF_Status.newStatus();
            tensorflow.TF_FunctionSetAttrValueProto(this.nativeHandle(), "_noinline", trueValue, bytes.length, status2);
            status2.throwExceptionIfNotOK();
        }
    }

    /**
     * Checks that two data-type lists have the same length and agree element-wise.
     * {@link DataType#DT_INVALID} on either side matches any type.
     */
    private static boolean dataTypesMatch(List<DataType> a, List<DataType> b) {
        if (a.size() != b.size()) {
            return false;
        }
        for (int i = 0; i < a.size(); ++i) {
            DataType aType = a.get(i);
            DataType bType = b.get(i);
            // Compare the elements at position i, not the whole lists. Otherwise a DT_INVALID
            // wildcard at one position would make every other position fail.
            if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && aType != bType) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the graph operation handle underlying {@code operand}.
     *
     * @throws NullPointerException if the operand or its native handle is null
     * @throws IllegalArgumentException if the handle is not a {@link TF_Operation}
     */
    private static TF_Operation outputHandle(Operand<?> operand) {
        if (operand == null) {
            throw new NullPointerException("Can't get output handle for null operand");
        }
        Pointer handle = operand.asOutput().getUnsafeNativeHandle();
        if (handle.isNull()) {
            throw new NullPointerException("Native handle of operand is null, has it been closed?");
        }
        if (!(handle instanceof TF_Operation)) {
            throw new IllegalArgumentException("Operand was not a graph operand");
        }
        return (TF_Operation) handle;
    }

    /**
     * Packs {@code operands} into a native {@link TF_Output} array, checking that each one belongs
     * to {@code graph}. The returned pointer is positioned at element 0.
     */
    private static TF_Output resolveToOutput(Graph graph, List<Operand<?>> operands) {
        TF_Output handles = new TF_Output(operands.size());
        for (int i = 0; i < operands.size(); ++i) {
            Operand<?> input = operands.get(i);
            graph.checkInput(input);
            TF_Operation handle = ConcreteFunction.outputHandle(input);
            handles.position(i).oper(handle).index(input.asOutput().index());
        }
        handles.position(0L);
        return handles;
    }

    /*
     * Exception decompiling.
     *
     * CFR 0.152 failed to decompile this method (org.benf.cfr.reader.util.ConfusedCFRException:
     * "Started 2 blocks at once"). The stub below always throws, so every create(...) factory that
     * delegates here fails at runtime. The surviving synthetic lambda bodies below suggest the
     * original validated the signature's inputs/outputs against the graph and rejected placeholders
     * not covered by the inputs.
     */
    private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) {
        throw new IllegalStateException("Decompilation failed");
    }

    // The following are compiler-synthesized lambda bodies from buildFromGraph, preserved
    // verbatim by the decompiler. Nothing in this file calls them.

    private static /* synthetic */ void lambda$buildFromGraph$6(List outputs, List inputs, GraphOperation x) {
        if (x.type().equals("Placeholder") || x.type().equals("PlaceholderWithDefault")) {
            throw new IllegalArgumentException("Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), they also depend on \"" + x + "\"");
        }
    }

    private static /* synthetic */ void lambda$buildFromGraph$5(List ops, Operand input) {
        ops.remove((GraphOperation) input.op());
    }

    private static /* synthetic */ Operand lambda$buildFromGraph$4(Graph graph, Map.Entry x) {
        return TensorFunction.validateDescription((Signature.TensorDescription) x.getValue(), graph, (String) x.getKey(), "Output");
    }

    private static /* synthetic */ Operand lambda$buildFromGraph$3(Graph graph, Map.Entry x) {
        return TensorFunction.validateDescription((Signature.TensorDescription) x.getValue(), graph, (String) x.getKey(), "Input");
    }
}

