/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.onnx;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import onnx.Onnx;

public class OnnxImporter {

    /**
     * Reads a serialized ONNX {@code ModelProto} from a file and imports the graph needed to
     * compute the given output node.
     *
     * @param modelPath  path of the file holding the serialized ONNX model
     * @param outputNode name of the graph node whose value the imported model should produce
     * @return the imported model
     * @throws IllegalArgumentException if the file cannot be read or parsed, or if the graph
     *                                  cannot be imported for {@code outputNode}
     */
    public OnnxModel importModel(String modelPath, String outputNode) {
        try (FileInputStream inputStream = new FileInputStream(modelPath)) {
            Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
            return importModel(model, outputNode);
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
        }
    }

    /**
     * Imports the graph of an already-parsed ONNX model for the given output node.
     *
     * @param model      the parsed ONNX model
     * @param outputNode name of the graph node whose value the imported model should produce
     * @return the imported model
     * @throws IllegalArgumentException if the graph cannot be imported for {@code outputNode}
     */
    public OnnxModel importModel(Onnx.ModelProto model, String outputNode) {
        return importGraph(model.getGraph(), outputNode);
    }

    /**
     * Builds the operation tree rooted at {@code outputNode}, checks its type against the graph's
     * declared output, resolves dimension names and converts the tree into model expressions,
     * constants and arguments.
     */
    private static OnnxModel importGraph(Onnx.GraphProto graph, String outputNode) {
        OnnxModel model = new OnnxModel(outputNode);
        OperationIndex index = new OperationIndex();

        OnnxOperation output = importNode(outputNode, graph, index);
        OrderedTensorType outputType = output.type()
                .orElseThrow(() -> new IllegalArgumentException("Output of '" + outputNode + "' has no type."));

        // Previously a missing graph output surfaced as a NullPointerException; report it explicitly.
        Onnx.ValueInfoProto outputInfo = getOutputNode(outputNode, graph);
        if (outputInfo == null) {
            throw new IllegalArgumentException("Node '" + outputNode + "' is not an output of the ONNX graph");
        }
        outputType.verifyType(outputInfo.getType());

        findDimensionNames(output);
        importExpressions(output, model);
        return model;
    }

    /**
     * Returns the operation for {@code nodeName}, importing it (and recursively its inputs) on
     * first use. Results are memoized in {@code index} so shared sub-graphs are imported once.
     *
     * A name is resolved, in order, as a constant (graph initializer), an argument (graph input
     * without an initializer) or a computation node.
     */
    private static OnnxOperation importNode(String nodeName, Onnx.GraphProto graph, OperationIndex index) {
        if (index.alreadyImported(nodeName)) {
            return index.get(nodeName);
        }
        OnnxOperation operation;
        Onnx.TensorProto constant = getConstantTensor(nodeName, graph);
        Onnx.ValueInfoProto argument = getArgumentTensor(nodeName, graph);
        if (constant != null) {
            // Since ONNX IR version 4, initializers need not also be listed among graph inputs,
            // so an initializer alone is enough to treat the name as a constant.
            operation = new Constant(constant);
        } else if (argument != null) {
            operation = new Argument(argument);
        } else {
            Onnx.NodeProto node = getNodeFromGraph(nodeName, graph);
            List<OnnxOperation> inputs = importNodeInputs(node, graph, index);
            operation = OperationMapper.get(node, inputs);
        }
        index.put(nodeName, operation);
        return operation;
    }

    /** Returns the graph input named {@code name}, or {@code null} if there is none. */
    private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
        for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
            if (valueInfo.getName().equals(name)) {
                return valueInfo;
            }
        }
        return null;
    }

    /** Returns the graph initializer named {@code name}, or {@code null} if there is none. */
    private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
        for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
            if (tensorProto.getName().equals(name)) {
                return tensorProto;
            }
        }
        return null;
    }

    /** Returns whether {@code name} is the producing node of one of the graph's outputs. */
    private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
        return getOutputNode(name, graph) != null;
    }

    /**
     * Returns the graph output whose producing node is named {@code name}, or {@code null} if
     * there is none. Outputs whose producing node cannot be located are skipped rather than
     * aborting the search, so unrelated outputs cannot make the lookup fail.
     */
    private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
        for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
            boolean producedByNode = findNodeInGraph(valueInfo.getName(), graph)
                    .map(Onnx.NodeProto::getName)
                    .filter(name::equals)
                    .isPresent();
            if (producedByNode) {
                return valueInfo;
            }
        }
        return null;
    }

    /** Imports every input of {@code node}, in declaration order. */
    private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node, Onnx.GraphProto graph, OperationIndex index) {
        return node.getInputList().stream()
                .map(inputName -> importNode(inputName, graph, index))
                .collect(Collectors.toList());
    }

    /** Collects dimension-name constraints from the typed operation tree, solves them and applies the result. */
    private static void findDimensionNames(OnnxOperation output) {
        DimensionRenamer renamer = new DimensionRenamer();
        addDimensionNameConstraints(output, renamer);
        renamer.solve();
        renameDimensions(output, renamer);
    }

    /** Adds constraints for the inputs first, then for {@code operation}; untyped operations are skipped. */
    private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
            operation.addDimensionNameConstraints(renamer);
        }
    }

    /** Applies the solved names to the inputs first, then to {@code operation}; untyped operations are skipped. */
    private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) {
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> renameDimensions(input, renamer));
            operation.renameDimensions(renamer);
        }
    }

    /**
     * Imports expressions for the whole tree rooted at {@code output}.
     *
     * @throws IllegalArgumentException if no function can be produced for the output
     */
    private static void importExpressions(OnnxOperation output, OnnxModel model) {
        importExpression(output, model)
                .orElseThrow(() -> new IllegalArgumentException("No valid output function could be found."));
    }

    /**
     * Imports {@code operation} into {@code model}: constants are registered as model constants;
     * other operations have their inputs imported first, then their own ranking expression and,
     * for inputs, their argument declaration.
     *
     * @return the operation's function, or empty if the operation has no type
     */
    private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) {
        if (!operation.type().isPresent()) {
            return Optional.empty();
        }
        if (operation.isConstant()) {
            return importConstant(operation, model);
        }
        importInputExpressions(operation, model);
        importRankingExpression(operation, model);
        importInputExpression(operation, model);
        return operation.function();
    }

    private static void importInputExpressions(OnnxOperation operation, OnnxModel model) {
        operation.inputs().forEach(input -> importExpression(input, model));
    }

    /**
     * Registers a constant operation's tensor value with the model once: rank-0 tensors as small
     * constants, others as large constants. Non-tensor values are not registered.
     */
    private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) {
        String name = operation.vespaName();
        if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
            return operation.function();
        }
        Value value = operation.getConstantValue()
                .orElseThrow(() -> new IllegalArgumentException("Operation '" + operation.vespaName() + "' is constant but does not have a value."));
        if (!(value instanceof TensorValue)) {
            return operation.function();
        }
        Tensor tensor = value.asTensor();
        if (tensor.type().rank() == 0) {
            model.smallConstant(name, tensor);
        } else {
            model.largeConstant(name, tensor);
        }
        return operation.function();
    }

    /**
     * Adds {@code operation}'s function to the model as a named ranking expression, unless it has
     * no function or an expression of that name already exists. The model output is additionally
     * wrapped in a {@link Rename} to standard dimension names when its type differs from the
     * standard naming.
     */
    private static void importRankingExpression(OnnxOperation operation, OnnxModel model) {
        Optional<TensorFunction> operationFunction = operation.function();
        if (!operationFunction.isPresent()) {
            return;
        }
        String name = operation.vespaName();
        if (model.expressions().containsKey(name)) {
            return;
        }

        TensorFunction function = operationFunction.get();
        if (name.equals(model.output())) {
            OrderedTensorType operationType = operation.type().get();
            OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
            if (!operationType.equals(standardNamingType)) {
                function = new Rename(function, operationType.dimensionNames(), standardNamingType.dimensionNames());
            }
        }
        try {
            model.expression(name, new RankingExpression(name, function.toString()));
        } catch (ParseException e) {
            throw new RuntimeException("ONNX function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /** Declares an input operation as a model argument and required macro, using the standard naming of its type. */
    private static void importInputExpression(OnnxOperation operation, OnnxModel model) {
        if (operation.isInput()) {
            // importExpression only calls this for typed operations, so type() is present here.
            OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
            model.argument(operation.vespaName(), standardNamingConvention.type());
            model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
        }
    }

    /**
     * Returns the node matching {@code nodeName}.
     *
     * @throws IllegalArgumentException if no node matches
     */
    private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
        return findNodeInGraph(nodeName, graph)
                .orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"));
    }

    /**
     * Finds the first node matching {@code nodeName}. A name containing ':' is matched against
     * node output names; any other name is matched against node names.
     */
    private static Optional<Onnx.NodeProto> findNodeInGraph(String nodeName, Onnx.GraphProto graph) {
        boolean hasPortNumber = nodeName.contains(":");
        for (Onnx.NodeProto node : graph.getNodeList()) {
            boolean matches = hasPortNumber
                    ? node.getOutputList().contains(nodeName)
                    : node.getName().equals(nodeName);
            if (matches) {
                return Optional.of(node);
            }
        }
        return Optional.empty();
    }

    /** Memo of already-imported operations, keyed by the ONNX name used to import them. */
    private static class OperationIndex {
        private final Map<String, OnnxOperation> index = new HashMap<>();

        private OperationIndex() {
        }

        public OnnxOperation put(String key, OnnxOperation operation) {
            return index.put(key, operation);
        }

        public OnnxOperation get(String key) {
            return index.get(key);
        }

        public boolean alreadyImported(String key) {
            return index.containsKey(key);
        }

        public Collection<OnnxOperation> operations() {
            return index.values();
        }
    }
}

