/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.onnx;

import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.onnx.TensorConverter;
import ai.vespa.rankingexpression.importer.onnx.TypeConverter;
import ai.vespa.rankingexpression.importer.operations.Argument;
import ai.vespa.rankingexpression.importer.operations.ConcatV2;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.Identity;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.Join;
import ai.vespa.rankingexpression.importer.operations.Map;
import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.functions.ScalarFunctions;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import onnx.Onnx;

/**
 * Converts an ONNX model into an {@link IntermediateGraph}.
 *
 * <p>Import starts at the ONNX graph outputs and recursively pulls in every operation they
 * depend on. Each referenced name is resolved as an initializer (constant), a graph input
 * (argument), or the output of a node in the graph.
 */
class GraphImporter {
    GraphImporter() {
    }

    /**
     * Maps a single ONNX node to the corresponding intermediate operation.
     * Unsupported op types become a {@link NoOp} carrying a warning instead of failing the import.
     *
     * @param node   the ONNX node to convert
     * @param inputs the already-imported operations feeding this node, in ONNX input order
     * @param graph  the graph being built; only its name is read here
     * @return the intermediate operation representing {@code node}
     */
    private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, IntermediateGraph graph) {
        String nodeName = node.getName();
        String modelName = graph.name();
        // Locale.ROOT: with the default locale (e.g. Turkish) "Identity" would lower-case to
        // "ıdentity" (dotless i) and silently fall through to NoOp.
        switch (node.getOpType().toLowerCase(Locale.ROOT)) {
            case "abs":        return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
            case "add":        return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
            case "acos":       return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
            case "asin":       return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
            case "atan":       return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
            case "ceil":       return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
            case "concat":     return new ConcatV2(modelName, nodeName, inputs);
            case "cos":        return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
            case "div":        return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
            case "elu":        return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
            case "equal":      return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
            case "exp":        return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
            case "floor":      return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
            case "greater":    return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
            case "identity":   return new Identity(modelName, nodeName, inputs);
            case "less":       return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
            case "log":        return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
            case "matmul":     return new MatMul(modelName, nodeName, inputs);
            case "max":        return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
            case "min":        return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
            case "mean":       return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
            case "mul":        return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
            case "neg":        return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
            case "pow":        return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
            case "reshape":    return new Reshape(modelName, nodeName, inputs);
            case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
            case "relu":       return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
            case "selu":       return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
            case "shape":      return new Shape(modelName, nodeName, inputs);
            case "sin":        return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
            case "sqrt":       return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
            case "sigmoid":    return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
            case "sub":        return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
            case "tan":        return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
            case "tanh":       return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
            default:           break;
        }
        NoOp op = new NoOp(modelName, nodeName, inputs);
        op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
        return op;
    }

    /**
     * Imports an ONNX model into a new intermediate graph and verifies the types of its outputs.
     *
     * @param modelName name given to the resulting graph
     * @param model     the ONNX model to import
     * @return the populated intermediate graph
     * @throws IllegalArgumentException if a referenced tensor or node cannot be found, or an
     *                                  output has no type or a type that does not verify
     */
    static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
        Onnx.GraphProto onnxGraph = model.getGraph();
        IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
        importOperations(onnxGraph, intermediateGraph);
        verifyOutputTypes(onnxGraph, intermediateGraph);
        return intermediateGraph;
    }

    /** Imports every graph output, which transitively imports everything those outputs depend on. */
    private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
        for (Onnx.ValueInfoProto output : onnxGraph.getOutputList()) {
            importOperation(output.getName(), onnxGraph, intermediateGraph);
        }
    }

    /**
     * Imports the operation producing {@code name} (and, recursively, its inputs), registering it
     * in {@code intermediateGraph}. Names already imported are returned as-is.
     *
     * <p>Resolution order: an initializer becomes a {@link Constant} (whether or not it is also
     * listed among the graph inputs), a remaining graph input becomes an {@link Argument}, and
     * anything else must be produced by a node in the graph.
     *
     * @throws IllegalArgumentException if {@code name} is none of the above
     */
    private static IntermediateOperation importOperation(String name, Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
        if (intermediateGraph.alreadyImported(name)) {
            return intermediateGraph.get(name);
        }
        Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
        Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);

        IntermediateOperation operation;
        if (tensorProto != null) {
            // Initializers need not be listed among graph inputs (ONNX IR >= 4), so the
            // initializer alone decides that this is a constant.
            OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
            operation = new Constant(intermediateGraph.name(), name, defaultType);
            // The value function converts the stored tensor to whatever type it is called with.
            operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
        } else if (valueInfoProto != null) {
            OrderedTensorType argumentType = TypeConverter.typeFrom(valueInfoProto.getType());
            operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), argumentType);
            intermediateGraph.inputs(intermediateGraph.defaultSignature()).put(IntermediateOperation.namePartOf(name), operation.vespaName());
        } else {
            Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
            List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
            operation = mapOperation(node, inputs, intermediateGraph);
            if (isOutputNode(name, onnxGraph)) {
                intermediateGraph.outputs(intermediateGraph.defaultSignature()).put(IntermediateOperation.namePartOf(name), operation.vespaName());
            }
        }
        intermediateGraph.put(operation.vespaName(), operation);
        return operation;
    }

    /** Returns the graph input named exactly {@code name}, or null if there is none. */
    private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
        for (Onnx.ValueInfoProto input : graph.getInputList()) {
            if (input.getName().equals(name)) {
                return input;
            }
        }
        return null;
    }

    /** Returns the initializer named exactly {@code name}, or null if there is none. */
    private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
        for (Onnx.TensorProto initializer : graph.getInitializerList()) {
            if (initializer.getName().equals(name)) {
                return initializer;
            }
        }
        return null;
    }

    /** Returns whether {@code name} identifies a graph output (see {@link #getOutputNode}). */
    private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
        return getOutputNode(name, graph) != null;
    }

    /**
     * Returns the graph output whose full name, or whose name part as given by
     * {@link IntermediateOperation#namePartOf}, equals {@code name}; null if there is none.
     */
    private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
        for (Onnx.ValueInfoProto output : graph.getOutputList()) {
            String outputName = output.getName();
            if (outputName.equals(name) || IntermediateOperation.namePartOf(outputName).equals(name)) {
                return output;
            }
        }
        return null;
    }

    /** Imports every input of {@code node}, preserving the ONNX input order. */
    private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node, Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
        return node.getInputList().stream()
                   .map(inputName -> importOperation(inputName, onnxGraph, intermediateGraph))
                   .collect(Collectors.toList());
    }

    /**
     * Checks that every registered output operation has a type compatible with the type declared
     * for it in the ONNX graph.
     *
     * @throws IllegalArgumentException if an output has no type, or its ONNX declaration cannot be found
     */
    private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
        // Fully qualified because the simple name Map refers to operations.Map in this file.
        for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
            String onnxOutputName = output.getKey();
            String outputName = output.getValue();
            IntermediateOperation operation = intermediateGraph.get(outputName);
            // Keys are ONNX output name parts; values are Vespa names, which may differ from the
            // ONNX names. Match on the key first and keep the old value-based lookup as a fallback.
            Onnx.ValueInfoProto onnxNode = getOutputNode(onnxOutputName, onnxGraph);
            if (onnxNode == null) {
                onnxNode = getOutputNode(outputName, onnxGraph);
            }
            if (onnxNode == null) {
                throw new IllegalArgumentException("Could not find ONNX output '" + onnxOutputName + "'");
            }
            OrderedTensorType type = operation.type().orElseThrow(() -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
            TypeConverter.verifyType(onnxNode.getType(), type);
        }
    }

    /**
     * Finds the node producing {@code nodeName}. Names containing ':' (a port number) are matched
     * against node output names; other names are matched against node names only.
     *
     * @throws IllegalArgumentException if no node matches
     */
    private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
        boolean hasPortNumber = nodeName.contains(":");
        for (Onnx.NodeProto node : graph.getNodeList()) {
            boolean matches = hasPortNumber ? node.getOutputList().contains(nodeName) : node.getName().equals(nodeName);
            if (matches) {
                return node;
            }
        }
        throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
    }
}

