/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.OperationMapper;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TypedTensorFunction;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;

/**
 * Imports TensorFlow SavedModel bundles into an {@link ImportResult}.
 * <p>
 * For every signature in the model's meta graph:
 * <ul>
 *   <li>each signature input is registered as a typed argument of the result, and</li>
 *   <li>each signature output node is converted, together with the nodes it transitively reads
 *       from, into ranking expressions.</li>
 * </ul>
 * An output whose conversion fails with an {@link IllegalArgumentException} is skipped, and a
 * warning is added to the result instead of failing the whole import.
 */
public class TensorFlowImporter {

    /** The SavedModel tag set loaded by {@link #importModel(String)}. */
    private static final String SERVING_TAG = "serve";

    private final OperationMapper operationMapper = new OperationMapper();

    /**
     * Loads the SavedModel in the given directory, using the "serve" tag, and imports it.
     * The loaded bundle is closed before this returns.
     *
     * @param modelDir the directory containing the SavedModel
     * @return the import result
     * @throws IllegalArgumentException if loading or importing throws an IllegalArgumentException;
     *         it is rethrown with the model directory in the message
     */
    public ImportResult importModel(String modelDir) {
        try (SavedModelBundle model = SavedModelBundle.load(modelDir, SERVING_TAG)) {
            return importModel(model);
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
        }
    }

    /**
     * Imports an already loaded SavedModel bundle. The bundle is not closed by this method.
     *
     * @param model the loaded bundle
     * @return the import result
     * @throws IllegalArgumentException if the bundle's meta graph definition cannot be parsed
     */
    public ImportResult importModel(SavedModelBundle model) {
        try {
            return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
        }
    }

    /**
     * Imports all signatures of the given meta graph.
     * Failing outputs are skipped and reported as warnings on the result.
     */
    private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
        ImportResult result = new ImportResult();
        GraphDef graphDef = graph.getGraphDef();
        for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
            ImportResult.Signature signature = result.signature(signatureEntry.getKey());
            SignatureDef signatureDef = signatureEntry.getValue();
            importInputs(signatureDef.getInputsMap(), signature);
            for (Map.Entry<String, TensorInfo> output : signatureDef.getOutputsMap().entrySet()) {
                String outputName = output.getKey();
                try {
                    String nodeName = nameOf(output.getValue().getName());
                    NodeDef node = getNode(nodeName, graphDef);
                    importNode(node, graphDef, model, result);
                    // Register the output only after the node has been imported successfully
                    signature.output(outputName, nodeName);
                }
                catch (IllegalArgumentException e) {
                    result.warn("Skipping output '" + outputName + "' of " + signature + ": " + Exceptions.toMessageString(e));
                }
            }
        }
        return result;
    }

    /**
     * Registers each signature input as an argument of the owning result, typed from its tensor
     * shape, and maps the signature's input name to that argument.
     */
    private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) {
        inputInfoMap.forEach((key, value) -> {
            String argumentName = nameOf(value.getName());
            TensorType argumentType = importTensorType(value.getTensorShape());
            signature.owner().argument(argumentName, argumentType);
            signature.input(key, argumentName);
        });
    }

    /**
     * Converts a TensorFlow shape into an indexed tensor type with dimensions named d0, d1, ...
     * Non-negative sizes become bound dimensions; negative sizes become unbound dimensions.
     */
    private TensorType importTensorType(TensorShapeProto tensorShape) {
        TensorType.Builder b = new TensorType.Builder();
        for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
            // Keep the size as a long: narrowing to int could wrap large sizes, or flip their sign
            long dimensionSize = dimension.getSize();
            if (dimensionSize >= 0) {
                b.indexed("d" + b.rank(), dimensionSize);
            }
            else {
                b.indexed("d" + b.rank());
            }
        }
        return b.build();
    }

    /**
     * Converts the given node, and recursively its inputs, to a tensor function.
     * Also records a ranking expression named after the node in the result.
     */
    private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
        TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
        result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
        return function;
    }

    /**
     * Maps a single TensorFlow operation to a tensor function.
     *
     * @throws IllegalArgumentException if the operation type is not supported
     */
    private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
        // Locale.ROOT: default-locale lowercasing breaks e.g. "Identity" under a Turkish locale
        switch (tfNode.getOp().toLowerCase(Locale.ROOT)) {
            case "add":
            // NOTE(review): TensorFlow's op type for tf.add_n appears to be "AddN" ("addn" lowercased),
            // so this label may be unreachable; confirm OperationMapper.join handles >2 args before adding "addn".
            case "add_n":
                return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
            case "acos":
                return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
            case "placeholder":
                return operationMapper.placeholder(tfNode, result);
            case "identity":
                return operationMapper.identity(tfNode, model, result);
            case "matmul":
                return operationMapper.matmul(importArguments(tfNode, graph, model, result));
            case "softmax":
                return operationMapper.softmax(importArguments(tfNode, graph, model, result));
            default:
                throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
        }
    }

    /** Imports every input of the given node, in input order. */
    private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
        return tfNode.getInputList().stream()
                     .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
                     .collect(Collectors.toList());
    }

    /**
     * Returns the first node in the graph with the given name.
     *
     * @throws IllegalArgumentException if no such node exists
     */
    private NodeDef getNode(String name, GraphDef graph) {
        return graph.getNodeList().stream()
                    .filter(node -> node.getName().equals(name))
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
    }

    /**
     * Strips everything from the first ':' onward, e.g. an output index in a tensor name "node:0",
     * and returns the node name.
     * NOTE(review): control-dependency inputs ("^node") are not stripped here; confirm whether they can occur.
     */
    private String nameOf(String name) {
        int colon = name.indexOf(':');
        return colon < 0 ? name : name.substring(0, colon);
    }
}

