/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/**
 * Imports a TensorFlow SavedModel into a {@link TensorFlowModel}.
 *
 * The import records the model's signatures, and the ranking expressions for the operations reachable
 * from the signature outputs. It also records the constants (small and large), the macros, and the
 * arguments/required macros for signature inputs.
 */
public class TensorFlowImporter {

    /**
     * Imports the SavedModel stored in the given directory, loaded with the "serve" tag.
     * The loaded bundle is closed before this method returns.
     *
     * @param modelDir the directory containing the saved model
     * @return the imported model
     * @throws IllegalArgumentException if importing fails with an IllegalArgumentException,
     *         which is kept as the cause
     */
    public TensorFlowModel importModel(String modelDir) {
        try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
            return importModel(model);
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
        }
    }

    /** Imports the SavedModel stored in the given directory. See {@link #importModel(String)}. */
    public TensorFlowModel importModel(File modelDir) {
        return importModel(modelDir.toString());
    }

    /**
     * Imports an already loaded SavedModel bundle. The bundle is not closed by this method.
     *
     * @throws IllegalArgumentException if the bundle's meta graph cannot be parsed
     */
    public TensorFlowModel importModel(SavedModelBundle model) {
        try {
            return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
        }
    }

    /** Runs the import passes in order: signatures, nodes, dimension names, expressions, warnings. */
    private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) {
        TensorFlowModel model = new TensorFlowModel();
        OperationIndex index = new OperationIndex();

        importSignatures(graph, model);
        importNodes(graph, model, index);
        findDimensionNames(model, index);
        importExpressions(model, index, bundle);
        reportWarnings(model, index);

        return model;
    }

    /**
     * Registers every signature and its named inputs and outputs with the model.
     * Only the node-name part of each tensor name is kept (any ":port" suffix and leading "^" are dropped).
     */
    private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) {
        for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
            TensorFlowModel.Signature signature = model.signature(signatureEntry.getKey());
            SignatureDef signatureDef = signatureEntry.getValue();
            for (Map.Entry<String, TensorInfo> input : signatureDef.getInputsMap().entrySet()) {
                signature.input(input.getKey(), namePartOf(input.getValue().getName()));
            }
            for (Map.Entry<String, TensorInfo> output : signatureDef.getOutputsMap().entrySet()) {
                signature.output(output.getKey(), namePartOf(output.getValue().getName()));
            }
        }
    }

    /** Returns whether the operation's node name is an input of any signature in the model. */
    private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) {
        for (TensorFlowModel.Signature signature : model.signatures().values()) {
            for (String inputName : signature.inputs().values()) {
                if (inputName.equals(operation.node().getName())) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns whether the operation's node name is an output of any signature in the model. */
    private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) {
        for (TensorFlowModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                if (outputName.equals(operation.node().getName())) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Imports the operation graph reachable from every signature output into the index. */
    private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
        for (TensorFlowModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                importNode(outputName, graph.getGraphDef(), index);
            }
        }
    }

    /**
     * Imports the named node, recursively importing its data inputs first and its control inputs afterwards.
     * Operations already in the index (keyed by the full name as given) are returned as-is.
     */
    private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) {
        if (index.alreadyImported(name)) {
            return index.get(name);
        }
        NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph);
        List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index);
        TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name));
        // Registered before control inputs are imported, so a control input referring back by the
        // same key finds this operation in the index.
        index.put(name, operation);

        List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index);
        if (!controlInputs.isEmpty()) {
            operation.setControlInputs(controlInputs);
        }
        return operation;
    }

    /** Imports the node's data inputs (those not prefixed with "^"). */
    private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) {
        return node.getInputList().stream()
                   .filter(name -> !isControlDependency(name))
                   .map(name -> importNode(name, graph, index))
                   .collect(Collectors.toList());
    }

    /** Imports the node's control inputs (those prefixed with "^"). */
    private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) {
        return node.getInputList().stream()
                   .filter(TensorFlowImporter::isControlDependency)
                   .map(name -> importNode(name, graph, index))
                   .collect(Collectors.toList());
    }

    /** Returns whether the input name denotes a control dependency (a "^" prefix). */
    private static boolean isControlDependency(String name) {
        return name.startsWith("^");
    }

    /**
     * Collects dimension name constraints from all operations reachable from signature outputs,
     * solves them, and then applies the resulting names to the same operations.
     */
    private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
        DimensionRenamer renamer = new DimensionRenamer();
        for (TensorFlowModel.Signature signature : model.signatures().values()) {
            for (String output : signature.outputs().values()) {
                addDimensionNameConstraints(index.get(output), renamer);
            }
        }
        renamer.solve();
        for (TensorFlowModel.Signature signature : model.signatures().values()) {
            for (String output : signature.outputs().values()) {
                renameDimensions(index.get(output), renamer);
            }
        }
    }

    /** Adds constraints for the inputs first, then for the operation; operations without a type are skipped. */
    private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
            operation.addDimensionNameConstraints(renamer);
        }
    }

    /** Renames dimensions of the inputs first, then of the operation; operations without a type are skipped. */
    private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> renameDimensions(input, renamer));
            operation.renameDimensions(renamer);
        }
    }

    /**
     * Imports the expressions for every signature output. Outputs without a function, or whose import
     * fails with an IllegalArgumentException, are recorded on the signature as skipped.
     */
    private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) {
        for (TensorFlowModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                try {
                    Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
                    if (!function.isPresent()) {
                        signature.skippedOutput(outputName, "No valid output function could be found.");
                    }
                }
                catch (IllegalArgumentException e) {
                    signature.skippedOutput(outputName, Exceptions.toMessageString(e));
                }
            }
        }
    }

    /**
     * Imports the expression for an operation. Constants are registered as model constants. Other operations
     * first import their inputs recursively, then register their ranking expression, signature-input
     * argument and macro.
     *
     * @return the operation's function, or empty if the operation has no type
     */
    private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model,
                                                             SavedModelBundle bundle) {
        if (!operation.type().isPresent()) {
            return Optional.empty();
        }
        if (operation.isConstant()) {
            return importConstant(model, operation, bundle);
        }
        importInputExpressions(operation, model, bundle);
        importRankingExpression(model, operation);
        importInputExpression(model, operation);
        importMacroExpression(model, operation);
        return operation.function();
    }

    /** Imports the expressions of all inputs of the operation (results are not used here). */
    private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model,
                                               SavedModelBundle bundle) {
        operation.inputs().forEach(input -> importExpression(input, model, bundle));
    }

    /**
     * Registers the operation's macro, if it has one, as a model macro under {@code operation.macroName()}.
     *
     * @throws RuntimeException if the macro function cannot be parsed as a ranking expression
     */
    private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
        if (!operation.macro().isPresent()) {
            return;
        }
        TensorFunction function = operation.macro().get();
        try {
            model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
        }
        catch (ParseException e) {
            throw new RuntimeException("Tensorflow function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /**
     * Registers the value of a constant operation with the model and returns the operation's function.
     *
     * Already-registered names are not registered again. Constant values that are not tensors are not
     * registered. A value that is not yet known is read from the bundle's session and cached on the
     * operation. Tensors of rank 0, or with at most one cell, become small constants; all others
     * become large constants.
     */
    private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation,
                                                           SavedModelBundle bundle) {
        String name = operation.vespaName();
        if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
            return operation.function();
        }

        com.yahoo.tensor.Tensor tensor;
        if (operation.getConstantValue().isPresent()) {
            Value value = operation.getConstantValue().get();
            if (!(value instanceof TensorValue)) {
                return operation.function();
            }
            tensor = value.asTensor();
        } else {
            tensor = readVariable(operation, bundle);
            operation.setConstantValue(new TensorValue(tensor));
        }

        if (tensor.type().rank() == 0 || tensor.size() <= 1L) {
            model.smallConstant(name, tensor);
        } else {
            model.largeConstant(name, tensor);
        }
        return operation.function();
    }

    /**
     * Fetches the operation's node from the bundle's session and converts it to a Vespa tensor
     * using the operation's (dimension-renamed) type.
     *
     * @throws IllegalStateException if the fetch does not return exactly one tensor
     */
    private static com.yahoo.tensor.Tensor readVariable(TensorFlowOperation operation, SavedModelBundle bundle) {
        String nodeName = operation.node().getName();
        Session.Runner fetched = bundle.session().runner().fetch(nodeName);
        List<Tensor<?>> importedTensors = fetched.run();
        try {
            if (importedTensors.size() != 1) {
                throw new IllegalStateException("Expected 1 tensor from fetching " + nodeName + ", but got " + importedTensors.size());
            }
            return TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get());
        }
        finally {
            // Session.Runner.run() hands ownership of the returned native tensors to the caller,
            // so release them whether or not the conversion succeeded.
            for (Tensor<?> importedTensor : importedTensors) {
                importedTensor.close();
            }
        }
    }

    /**
     * Registers the operation's function as a ranking expression named after its TensorFlow node,
     * unless one is already registered under that name. For signature outputs whose type differs from
     * the standard TensorFlow naming, the function is wrapped in a Rename to the standard dimension names.
     *
     * @throws RuntimeException if the function cannot be parsed as a ranking expression
     */
    private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
        if (!operation.function().isPresent()) {
            return;
        }
        String name = operation.node().getName();
        if (model.expressions().containsKey(name)) {
            return;
        }

        TensorFunction function = operation.function().get();
        if (isSignatureOutput(model, operation)) {
            OrderedTensorType operationType = operation.type().get();
            OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
            if (!operationType.equals(standardNamingType)) {
                function = new Rename(function, operationType.dimensionNames(), standardNamingType.dimensionNames());
            }
        }
        try {
            model.expression(name, new RankingExpression(name, function.toString()));
        }
        catch (ParseException e) {
            throw new RuntimeException("Tensorflow function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /**
     * For input operations that are signature inputs, registers an argument under the TensorFlow node name
     * and a required macro under the Vespa name, both with the standard TensorFlow-naming type.
     */
    private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) {
        if (operation.isInput() && isSignatureInput(model, operation)) {
            OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
            model.argument(operation.node().getName(), standardNamingConvention.type());
            model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
        }
    }

    /**
     * Copies import warnings from every operation reachable from each signature's outputs onto that signature.
     * Each operation is reported at most once per signature.
     */
    private static void reportWarnings(TensorFlowModel model, OperationIndex index) {
        for (TensorFlowModel.Signature signature : model.signatures().values()) {
            // Identity-based so the result does not depend on how TensorFlowOperation defines equals/hashCode.
            Set<TensorFlowOperation> visited = Collections.newSetFromMap(new IdentityHashMap<>());
            for (String output : signature.outputs().values()) {
                reportWarnings(index.get(output), signature, visited);
            }
        }
    }

    /** Adds the warnings of the operation and, recursively, of its inputs, skipping already visited operations. */
    private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature,
                                       Set<TensorFlowOperation> visited) {
        if (!visited.add(operation)) {
            return;
        }
        for (String warning : operation.warnings()) {
            signature.importWarning(warning);
        }
        for (TensorFlowOperation input : operation.inputs()) {
            reportWarnings(input, signature, visited);
        }
    }

    /**
     * Returns the graph node with exactly the given name.
     *
     * @throws IllegalArgumentException if no such node exists
     */
    private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
        for (NodeDef node : graph.getNodeList()) {
            if (node.getName().equals(name)) {
                return node;
            }
        }
        throw new IllegalArgumentException("Could not find node '" + name + "'");
    }

    /** Strips a leading control-dependency "^" and any ":port" suffix from a tensor name. */
    private static String namePartOf(String name) {
        name = name.startsWith("^") ? name.substring(1) : name;
        return name.split(":")[0];
    }

    /** Returns the integer after the first ':' in a tensor name, or 0 when there is none. */
    private static int portPartOf(String name) {
        int i = name.indexOf(":");
        return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
    }

    /** Imported operations keyed by the tensor name used to reach them during import. */
    private static class OperationIndex {

        private final Map<String, TensorFlowOperation> index = new HashMap<>();

        public TensorFlowOperation put(String key, TensorFlowOperation operation) {
            return index.put(key, operation);
        }

        public TensorFlowOperation get(String key) {
            return index.get(key);
        }

        public boolean alreadyImported(String key) {
            return index.containsKey(key);
        }

    }

}

