/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.FunctionReference;
import ai.vespa.models.evaluation.OnnxModel;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Walks the expression trees of ranking-expression functions and collects, per function, the
 * names that must be bound before evaluation and the subset of those that are free arguments.
 * Results are memoized per {@link FunctionReference}.
 *
 * <p>Instances hold unsynchronized mutable state (memo table, in-progress set) and should not be
 * shared between threads.
 */
class BindingExtractor {
    /** Function bodies that {@code rankingExpression(...)} references may resolve to. */
    private final Map<FunctionReference, ExpressionFunction> referencedFunctions;
    /** ONNX models that {@code onnx(...)} / {@code onnxModel(...)} features may name. */
    private final List<OnnxModel> onnxModels;
    /** Memoized results; a {@code null} value records a reference with no known function. */
    private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>();
    /** References whose extraction is currently on the call stack, in call order. */
    private final Set<FunctionReference> inProgress = new LinkedHashSet<>();

    public BindingExtractor(Map<FunctionReference, ExpressionFunction> referencedFunctions, List<OnnxModel> onnxModels) {
        this.referencedFunctions = referencedFunctions;
        this.onnxModels = onnxModels;
    }

    /**
     * Returns the binding information for the function {@code ref} points to, computing and
     * caching it on first use.
     *
     * @param ref the function reference to resolve
     * @return the binding information, or {@code null} if {@code ref} has no entry among the
     *         referenced functions
     * @throws IllegalArgumentException if the function transitively references itself, or if a
     *         nested function reference cannot be parsed
     */
    FunctionInfo extractFrom(FunctionReference ref) {
        // containsKey rather than a null check on get(): a cached null ("unknown function") is a
        // valid, memoized result.
        if (functionsInfo.containsKey(ref)) {
            return functionsInfo.get(ref);
        }
        // Re-entering a reference that is still being extracted means the functions form a
        // cycle. Without this check the recursion never terminates (StackOverflowError).
        if ( ! inProgress.add(ref)) {
            throw new IllegalArgumentException("Invocation loop: " + describeLoop(ref));
        }
        try {
            // Deliberately not computeIfAbsent: the computation recurses and inserts into the
            // same map, which HashMap-based maps reject with ConcurrentModificationException.
            ExpressionFunction function = referencedFunctions.get(ref);
            FunctionInfo result = extractFrom(function);
            functionsInfo.put(ref, result);
            return result;
        } finally {
            inProgress.remove(ref);
        }
    }

    /** Renders the in-progress chain from the first occurrence of {@code ref} back to itself. */
    private String describeLoop(FunctionReference ref) {
        StringBuilder loop = new StringBuilder();
        boolean inLoop = false;
        for (FunctionReference visiting : inProgress) {
            inLoop = inLoop || visiting.equals(ref);
            if (inLoop) {
                loop.append(visiting.serialForm()).append(" -> ");
            }
        }
        return loop.append(ref.serialForm()).toString();
    }

    /**
     * Returns the binding information for the body of {@code function}.
     *
     * @param function the function to analyze; may be {@code null}
     * @return the binding information, or {@code null} if {@code function} is {@code null}
     */
    FunctionInfo extractFrom(ExpressionFunction function) {
        if (function == null) {
            return null;
        }
        return extractBindTargets(function.getBody().getRoot());
    }

    /**
     * Recursively collects bind targets and free arguments of the expression rooted at
     * {@code node}.
     *
     * <p>The order of the checks matters: the specific reference kinds (function reference,
     * ONNX, constant) must be recognized before the generic {@link ReferenceNode} case, and
     * {@link TensorFunctionNode} is handled before the generic {@link CompositeNode} case.
     */
    private FunctionInfo extractBindTargets(ExpressionNode node) {
        FunctionInfo result = new FunctionInfo();
        if (isFunctionReference(node)) {
            Optional<FunctionReference> opt = FunctionReference.fromSerial(node.toString());
            if (opt.isEmpty()) {
                throw new IllegalArgumentException("Could not extract function " + String.valueOf(node) + " from serialized form '" + node.toString() + "'");
            }
            FunctionReference reference = opt.get();
            result.bindTargets.add(reference.serialForm());
            FunctionInfo subInfo = extractFrom(reference);
            if (subInfo == null) {
                // Unknown function: its value must be supplied from outside.
                result.arguments.add(reference.serialForm());
            } else {
                result.merge(subInfo);
            }
            return result;
        }
        if (node instanceof TensorFunctionNode) {
            TensorFunctionNode tfn = (TensorFunctionNode) node;
            for (ExpressionNode child : tfn.children()) {
                result.merge(extractBindTargets(child));
            }
            // Used purely as a visitor over expressions embedded in the tensor function: each
            // expression is returned unchanged and the transformed node is discarded. Any overlap
            // with children() above is harmless because merge() accumulates into sets.
            tfn.withTransformedExpressions(expr -> {
                result.merge(extractBindTargets((ExpressionNode) expr));
                return expr;
            });
            TensorFunction f = tfn.function();
            if (f instanceof Generate) {
                // Inside a generate() body the dimension names act as index variables, so they
                // are not external inputs. NOTE(review): relies on Generate.type ignoring the
                // (null) type context.
                TensorType generatedType = f.type(null);
                for (TensorType.Dimension dimension : generatedType.dimensions()) {
                    result.removeTarget(dimension.name());
                }
            }
            return result;
        }
        if (isOnnx(node)) {
            return extractOnnxTargets(node);
        }
        if (isConstant(node)) {
            // Constants need a bound value but are not free arguments.
            result.bindTargets.add(node.toString());
            return result;
        }
        if (node instanceof ReferenceNode) {
            result.bindTargets.add(node.toString());
            result.arguments.add(node.toString());
            return result;
        }
        if (node instanceof CompositeNode) {
            for (ExpressionNode child : ((CompositeNode) node).children()) {
                result.merge(extractBindTargets(child));
            }
            return result;
        }
        // Leaves such as literal constants contribute nothing.
        return result;
    }

    /**
     * Collects targets for an {@code onnx(...)}/{@code onnxModel(...)} feature.
     *
     * <p>If the first argument names one of the known models, that model is loaded and its
     * input names become both bind targets and arguments. Otherwise, the feature itself is
     * treated as a free argument.
     */
    private FunctionInfo extractOnnxTargets(ExpressionNode node) {
        FunctionInfo result = new FunctionInfo();
        String onnxFeature = node.toString();
        result.bindTargets.add(onnxFeature);
        Optional<String> modelName = getArgument(node);
        if (modelName.isPresent()) {
            for (OnnxModel onnxModel : onnxModels) {
                if (onnxModel.name().equals(modelName.get())) {
                    // load() is invoked before inputs() is read.
                    onnxModel.load();
                    for (String input : onnxModel.inputs().keySet()) {
                        result.bindTargets.add(input);
                        result.arguments.add(input);
                    }
                    result.onnxModelsInUse.put(onnxFeature, onnxModel);
                    return result;
                }
            }
        }
        result.arguments.add(onnxFeature);
        return result;
    }

    /**
     * Returns the first argument of a reference node as a plain name: a constant argument with
     * its surrounding quotes stripped, or the name of a reference argument. Returns empty for
     * non-references, references without arguments, or any other argument kind.
     */
    private Optional<String> getArgument(ExpressionNode node) {
        if (!(node instanceof ReferenceNode)) {
            return Optional.empty();
        }
        ReferenceNode reference = (ReferenceNode) node;
        if (reference.getArguments().size() == 0) {
            return Optional.empty();
        }
        ExpressionNode arg = (ExpressionNode) reference.getArguments().expressions().get(0);
        if (arg instanceof ConstantNode) {
            return Optional.of(stripQuotes(arg.toString()));
        }
        if (arg instanceof ReferenceNode) {
            return Optional.of(((ReferenceNode) arg).getName());
        }
        return Optional.empty();
    }

    /**
     * Removes one pair of matching surrounding double or single quotes.
     *
     * <p>Strings shorter than three characters are returned unchanged. In particular, an empty
     * quoted string such as {@code ""} is not stripped.
     *
     * @param s the possibly-quoted string
     * @return {@code s} without its surrounding quotes, or {@code s} itself if it is not quoted
     */
    public static String stripQuotes(String s) {
        if (s.length() < 3) {
            return s;
        }
        int lastIdx = s.length() - 1;
        char first = s.charAt(0);
        char last = s.charAt(lastIdx);
        boolean doubleQuoted = first == '"' && last == '"';
        boolean singleQuoted = first == '\'' && last == '\'';
        return (doubleQuoted || singleQuoted) ? s.substring(1, lastIdx) : s;
    }

    /** True for {@code rankingExpression(x)} references with exactly one argument. */
    private boolean isFunctionReference(ExpressionNode node) {
        if (!(node instanceof ReferenceNode)) {
            return false;
        }
        ReferenceNode reference = (ReferenceNode) node;
        return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
    }

    /** True for references named {@code onnx} or {@code onnxModel}, whatever their arguments. */
    private boolean isOnnx(ExpressionNode node) {
        if (!(node instanceof ReferenceNode)) {
            return false;
        }
        ReferenceNode reference = (ReferenceNode) node;
        return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
    }

    /** True for {@code constant(x)} references with exactly one argument. */
    private boolean isConstant(ExpressionNode node) {
        if (!(node instanceof ReferenceNode)) {
            return false;
        }
        ReferenceNode reference = (ReferenceNode) node;
        return reference.getName().equals("constant") && reference.getArguments().size() == 1;
    }

    /** Mutable accumulator of binding information for one expression or function. */
    static class FunctionInfo {
        /**
         * Names that must have a value bound: function references, constants, plain references,
         * ONNX feature strings and ONNX model inputs, in discovery order.
         */
        final Set<String> bindTargets = new LinkedHashSet<>();
        /**
         * Names treated as free arguments: plain references, function references this
         * extractor could not resolve, ONNX model inputs, and ONNX features whose model was not
         * found.
         */
        final Set<String> arguments = new LinkedHashSet<>();
        /** Models used by the expression, keyed by the full ONNX feature string. */
        final Map<String, OnnxModel> onnxModelsInUse = new LinkedHashMap<>();

        /** Adds everything collected in {@code other} to this instance. */
        void merge(FunctionInfo other) {
            bindTargets.addAll(other.bindTargets);
            arguments.addAll(other.arguments);
            onnxModelsInUse.putAll(other.onnxModelsInUse);
        }

        /** Removes {@code name} from both the bind targets and the arguments. */
        void removeTarget(String name) {
            bindTargets.remove(name);
            arguments.remove(name);
        }
    }
}

