/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.AttrValueConverter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.OperationMapper;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TypedTensorFunction;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/**
 * Converts a TensorFlow SavedModel into a {@link TensorFlowModel} whose signatures, inputs, outputs
 * and node expressions are expressed as Vespa ranking expressions.
 * <p>
 * Outputs whose nodes cannot be converted are not fatal: they are recorded on the signature as
 * skipped outputs together with the reason.
 */
public class TensorFlowImporter {

    /**
     * Imports the TensorFlow SavedModel stored in the given directory.
     * The model is loaded with the {@code "serve"} tag, and the loaded bundle is closed before returning.
     *
     * @param modelDir the directory containing the saved model
     * @return the imported model
     * @throws IllegalArgumentException if the import fails with an IllegalArgumentException; the
     *         original exception is kept as the cause and the message names the directory
     */
    public TensorFlowModel importModel(String modelDir) {
        try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
            return importModel(model);
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
        }
    }

    /**
     * Imports the TensorFlow SavedModel stored in the given directory.
     *
     * @param modelDir the directory containing the saved model
     * @return the imported model
     */
    public TensorFlowModel importModel(File modelDir) {
        return importModel(modelDir.toString());
    }

    /**
     * Imports an already loaded TensorFlow SavedModel. The bundle is not closed by this method.
     *
     * @param model the loaded model bundle
     * @return the imported model
     * @throws IllegalArgumentException if the serialized meta graph of the bundle cannot be parsed
     */
    public TensorFlowModel importModel(SavedModelBundle model) {
        try {
            return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
        }
    }

    /**
     * Imports every signature of the meta graph. For each signature, all inputs are registered and
     * each output node is imported. Outputs that cannot be imported are recorded as skipped on the
     * signature instead of failing the whole import.
     */
    private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) {
        TensorFlowModel result = new TensorFlowModel();
        GraphDef graphDef = graph.getGraphDef();
        for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
            // The signature map key is used as the signature name.
            TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey());
            SignatureDef signatureDef = signatureEntry.getValue();
            importInputs(signatureDef.getInputsMap(), signature);
            for (Map.Entry<String, TensorInfo> output : signatureDef.getOutputsMap().entrySet()) {
                String outputName = output.getKey();
                try {
                    String outputNodeName = namePartOf(output.getValue().getName());
                    NodeDef node = getNode(outputNodeName, graphDef);
                    // Each output gets its own memo of imported nodes (see createParameters).
                    Parameters params = createParameters(graphDef, model, result, signature, node, "");
                    Optional<TypedTensorFunction> outputFunction = importNode(params);
                    if ( ! outputFunction.isPresent()) {
                        throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n")));
                    }
                    signature.output(outputName, outputNodeName);
                }
                catch (IllegalArgumentException e) {
                    signature.skippedOutput(outputName, Exceptions.toMessageString(e));
                }
            }
        }
        return result;
    }

    /**
     * Registers each signature input as an argument on the owning model and maps the signature's input
     * name to that argument name. The argument type is converted from the TensorFlow tensor shape.
     */
    private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) {
        inputInfoMap.forEach((key, value) -> {
            String argumentName = namePartOf(value.getName());
            TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape());
            signature.owner().argument(argumentName, argumentType);
            signature.input(key, argumentName);
        });
    }

    /**
     * Imports the node held by {@code params}. Results are memoized by node name in
     * {@code params.imported()}. A successfully imported node is also added to the result model as a
     * ranking expression named after the node.
     *
     * @return the imported function, or empty if OperationMapper cannot map the node or one of its
     *         control dependencies cannot be imported
     * @throws RuntimeException if the mapped function cannot be parsed as a ranking expression
     */
    private Optional<TypedTensorFunction> importNode(Parameters params) {
        String nodeName = params.node().getName();
        if (params.imported().containsKey(nodeName)) {
            return Optional.of(params.imported().get(nodeName));
        }
        Optional<TypedTensorFunction> function = OperationMapper.map(params);
        if ( ! function.isPresent()) {
            return Optional.empty();
        }
        if ( ! controlDependenciesArePresent(params)) {
            return Optional.empty();
        }
        TypedTensorFunction importedFunction = function.get();
        params.imported().put(nodeName, importedFunction);
        try {
            params.result().expression(nodeName, new RankingExpression(nodeName, importedFunction.function().toString()));
            return function;
        }
        catch (ParseException e) {
            throw new RuntimeException("Tensorflow function " + importedFunction.function() + " cannot be parsed as a ranking expression", e);
        }
    }

    /**
     * Imports the control-dependency inputs (those prefixed with '^') of the current node.
     * Stops at the first one that cannot be imported.
     *
     * @return whether all control dependencies were imported
     */
    private boolean controlDependenciesArePresent(Parameters params) {
        return params.node().getInputList().stream()
                .filter(TensorFlowImporter::isControlDependency)
                .map(inputName -> importInput(params, inputName))
                .allMatch(Optional::isPresent);
    }

    /** Returns whether the given input reference denotes a control dependency ('^' prefix). */
    private static boolean isControlDependency(String nodeName) {
        return nodeName.startsWith("^");
    }

    /**
     * Imports the regular (non-control-dependency) inputs of the current node, in input order.
     * Inputs that could not be imported are present as empty optionals.
     */
    private List<Optional<TypedTensorFunction>> importArguments(Parameters params) {
        return params.node().getInputList().stream()
                .filter(inputName -> ! isControlDependency(inputName))
                .map(inputName -> importInput(params, inputName))
                .collect(Collectors.toList());
    }

    /**
     * Resolves an input reference such as {@code "name"}, {@code "name:1"} or {@code "^name"} to its
     * node and imports it. The part after ':' (or "") is passed as the port.
     */
    private Optional<TypedTensorFunction> importInput(Parameters params, String inputName) {
        NodeDef inputNode = getNode(namePartOf(inputName), params.graph());
        return importNode(params.copy(inputNode, indexPartOf(inputName)));
    }

    /**
     * Returns the first node in the graph with the given name, found by a linear scan of the node list.
     *
     * @throws IllegalArgumentException if no such node exists
     */
    private NodeDef getNode(String name, GraphDef graph) {
        return graph.getNodeList().stream()
                .filter(node -> node.getName().equals(name))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
    }

    /**
     * Returns the node-name part of an input reference: a leading '^' is dropped, along with
     * everything from the first ':' on.
     */
    private static String namePartOf(String name) {
        String withoutControlPrefix = isControlDependency(name) ? name.substring(1) : name;
        // indexOf rather than split(":")[0]: no regex, and no out-of-bounds failure on names like ":".
        int colon = withoutControlPrefix.indexOf(':');
        return colon < 0 ? withoutControlPrefix : withoutControlPrefix.substring(0, colon);
    }

    /** Returns everything after the first ':' of an input reference, or "" if there is none. */
    private static String indexPartOf(String name) {
        int colon = name.indexOf(':');
        return colon < 0 ? "" : name.substring(colon + 1);
    }

    /** Creates the parameters for importing a root node, with a fresh, empty memo of imported nodes. */
    private Parameters createParameters(GraphDef graph, SavedModelBundle model, TensorFlowModel result, TensorFlowModel.Signature signature, NodeDef node, String port) {
        return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port);
    }

    /**
     * The state passed through a recursive node import.
     * <p>
     * It holds the graph, the model bundle, the import result and signature, the current node and its
     * port (the part after ':' in the referencing input, or ""). It also holds the memo of imported
     * nodes, which is shared by every copy made with {@link #copy}.
     */
    static final class Parameters {
        private final TensorFlowImporter owner;
        private final GraphDef graph;
        private final SavedModelBundle model;
        private final TensorFlowModel result;
        private final TensorFlowModel.Signature signature;
        private final Map<String, TypedTensorFunction> imported;
        private final NodeDef node;
        private final String port;

        private Parameters(TensorFlowImporter owner, GraphDef graph, SavedModelBundle model, TensorFlowModel result, TensorFlowModel.Signature signature, Map<String, TypedTensorFunction> imported, NodeDef node, String port) {
            this.owner = owner;
            this.graph = graph;
            this.model = model;
            this.result = result;
            this.signature = signature;
            this.imported = imported;
            this.node = node;
            this.port = port;
        }

        GraphDef graph() {
            return graph;
        }

        SavedModelBundle model() {
            return model;
        }

        TensorFlowModel result() {
            return result;
        }

        TensorFlowModel.Signature signature() {
            return signature;
        }

        /** Returns the mutable memo of imported functions by node name, shared between copies. */
        Map<String, TypedTensorFunction> imported() {
            return imported;
        }

        NodeDef node() {
            return node;
        }

        String port() {
            return port;
        }

        /** Returns parameters for another node and port that share all other state, including the memo. */
        Parameters copy(NodeDef node, String port) {
            return new Parameters(owner, graph, model, result, signature, imported, node, port);
        }

        /** Imports the regular (non-control-dependency) inputs of this node. */
        List<Optional<TypedTensorFunction>> inputs() {
            return owner.importArguments(this);
        }
    }
}

