/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.OperationMapper;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TypedTensorFunction;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;

public class TensorFlowImporter {
    private final OperationMapper operationMapper = new OperationMapper();

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public ImportResult importModel(String modelDir) {
        try (SavedModelBundle model = SavedModelBundle.load((String)modelDir, (String[])new String[]{"serve"});){
            ImportResult importResult = this.importGraph(MetaGraphDef.parseFrom((byte[])model.metaGraphDef()), model);
            return importResult;
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) {
        try (SavedModelBundle model = SavedModelBundle.load((String)modelDir, (String[])new String[]{"serve"});){
            MetaGraphDef graph = MetaGraphDef.parseFrom((byte[])model.metaGraphDef());
            SignatureDef signature = (SignatureDef)graph.getSignatureDefMap().get(inputSignatureName);
            ImportResult result = new ImportResult();
            this.importInputs(signature.getInputsMap(), result);
            result.add(new RankingExpression(nodeName, this.importNode(nodeName, graph.getGraphDef(), model, result)));
            ImportResult importResult = result;
            return importResult;
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
        }
    }

    private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
        ImportResult result = new ImportResult();
        for (Map.Entry signatureEntry : graph.getSignatureDefMap().entrySet()) {
            this.importInputs(((SignatureDef)signatureEntry.getValue()).getInputsMap(), result);
            for (Map.Entry output : ((SignatureDef)signatureEntry.getValue()).getOutputsMap().entrySet()) {
                try {
                    ExpressionNode node = this.importOutput((TensorInfo)output.getValue(), graph.getGraphDef(), model, result);
                    result.add(new RankingExpression((String)output.getKey(), node));
                }
                catch (IllegalArgumentException e) {
                    result.warn("Skipping output '" + ((TensorInfo)output.getValue()).getName() + "' of signature '" + ((SignatureDef)signatureEntry.getValue()).getMethodName() + "': " + Exceptions.toMessageString((Throwable)e));
                }
            }
        }
        return result;
    }

    private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) {
        inputInfoMap.forEach((key, value) -> result.set(this.nameOf(value.getName()), this.importTensorType(value.getTensorShape())));
    }

    private TensorType importTensorType(TensorShapeProto tensorShape) {
        TensorType.Builder b = new TensorType.Builder();
        for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
            int dimensionSize = (int)dimension.getSize();
            if (dimensionSize >= 0) {
                b.indexed("d" + b.rank(), dimensionSize);
                continue;
            }
            b.indexed("d" + b.rank());
        }
        return b.build();
    }

    private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
        return this.importNode(this.nameOf(output.getName()), graph, model, result);
    }

    private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
        TensorFunction function = this.importNode(this.getNode(nodeName, graph), graph, model, result).function();
        return new TensorFunctionNode(function);
    }

    private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
        return this.tensorFunctionOf(tfNode, graph, model, result);
    }

    private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
        switch (tfNode.getOp().toLowerCase()) {
            case "add": 
            case "add_n": {
                return this.operationMapper.join(this.importArguments(tfNode, graph, model, result), ScalarFunctions.add());
            }
            case "acos": {
                return this.operationMapper.map(this.importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
            }
            case "placeholder": {
                return this.operationMapper.placeholder(tfNode, result);
            }
            case "identity": {
                return this.operationMapper.identity(tfNode, model, result);
            }
            case "matmul": {
                return this.operationMapper.matmul(this.importArguments(tfNode, graph, model, result));
            }
            case "softmax": {
                return this.operationMapper.softmax(this.importArguments(tfNode, graph, model, result));
            }
        }
        throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
    }

    private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
        return tfNode.getInputList().stream().map(argNode -> this.importNode(this.getNode(this.nameOf((String)argNode), graph), graph, model, result)).collect(Collectors.toList());
    }

    private NodeDef getNode(String name, GraphDef graph) {
        return graph.getNodeList().stream().filter(node -> node.getName().equals(name)).findFirst().orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
    }

    private String nameOf(String name) {
        return name.split(":")[0];
    }
}

