/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.BindingExtractor;
import ai.vespa.models.evaluation.Constant;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.FunctionReference;
import ai.vespa.models.evaluation.LazyArrayContext;
import ai.vespa.models.evaluation.OnnxModel;
import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * A named model: a set of ranking expression functions prepared for evaluation, together with the
 * referenced functions, constants and ONNX models those functions are evaluated with.
 * <p>
 * Functions whose name starts with {@code imported_ml_function_} are treated as intermediate operations:
 * they are not listed by {@link #functions()}, but can still be selected by {@link #evaluatorOf(String...)}.
 * <p>
 * Closing the model closes the ONNX models it was created with.
 */
@Beta
public class Model implements AutoCloseable {

    private static final Logger logger = Logger.getLogger(Model.class.getName());

    /** Name prefix of functions which are intermediate operations rather than public entry points. */
    private static final String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";

    private final String name;

    /** All functions of this model, including intermediate ones, after preparation and optimization. */
    private final List<ExpressionFunction> functions;

    /** The functions of this model whose names do not start with the intermediate prefix. */
    private final List<ExpressionFunction> publicFunctions;

    private final Map<FunctionReference, ExpressionFunction> referencedFunctions;

    /** Evaluation context prototypes by function name; each evaluator gets its own copy. */
    private final Map<String, LazyArrayContext> contextPrototypes;

    private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();

    /** Actions run on close: one per ONNX model given at construction. */
    private final List<Runnable> closeActions;

    /**
     * Creates a model from the given functions, with no referenced functions, declared types,
     * constants or ONNX models.
     *
     * @param name the name of this model
     * @param functions the functions of this model, keyed by {@link FunctionReference#fromName} of their names
     * @throws IllegalStateException if two functions map to the same function reference
     * @throws IllegalArgumentException if an evaluation context cannot be prepared for some function
     */
    public Model(String name, Collection<ExpressionFunction> functions) {
        // The constructor below updates this map in place, and Collectors.toMap makes no guarantee that
        // the map it returns is mutable, so copy it. The LinkedHashMap copy keeps toMap's iteration order.
        this(name,
             new LinkedHashMap<FunctionReference, ExpressionFunction>(
                     functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f))),
             Map.of(), Map.of(), List.of(), List.of());
    }

    /**
     * Creates a model and prepares each of its functions for evaluation. For each function, ONNX model
     * references in the body are replaced by the expression of the referenced output. Then an evaluation
     * context is built, a missing return type is set to the empty tensor type, and arguments required by
     * the context are added. Finally, all function bodies are optimized.
     * <p>
     * NOTE: {@code functions} is updated in place with the prepared functions. {@code declaredTypes} receives
     * the input types wanted by referenced ONNX models. Both must therefore be mutable whenever such updates
     * are needed.
     *
     * @param name the name of this model
     * @param functions the functions of this model, by reference
     * @param referencedFunctions functions passed to the evaluation contexts of this model's functions
     * @param declaredTypes argument types, used for arguments which do not already have a type
     * @param constants constants passed to the evaluation contexts of this model's functions
     * @param onnxModels ONNX models which may be referenced; closed when this model is closed
     * @throws IllegalArgumentException if an evaluation context cannot be prepared for some function
     */
    Model(String name,
          Map<FunctionReference, ExpressionFunction> functions,
          Map<FunctionReference, ExpressionFunction> referencedFunctions,
          Map<String, TensorType> declaredTypes,
          List<Constant> constants,
          List<OnnxModel> onnxModels) {
        this.name = name;
        BindingExtractor bindingExtractor = new BindingExtractor(referencedFunctions, onnxModels);
        Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
        for (Map.Entry<FunctionReference, ExpressionFunction> entry : functions.entrySet()) {
            try {
                ExpressionFunction function = entry.getValue();
                RankingExpression body = function.getBody();
                body.setRoot(new OnnxReplacer(onnxModels, declaredTypes).transform(body.getRoot(), null));

                LazyArrayContext context = new LazyArrayContext(function, bindingExtractor, referencedFunctions, constants, this);
                contextBuilder.put(function.getName(), context);

                if (function.returnType().isEmpty())
                    function = function.withReturnType(TensorType.empty);
                for (OnnxModel onnxModel : context.onnxModels().values()) {
                    for (Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet())
                        function = function.withArgument(input.getKey(), input.getValue());
                }
                for (String argument : context.arguments())
                    function = withContextArgument(function, argument, declaredTypes);

                // Write back only when something changed, so a map needing no update may be read-only.
                // Replacing the value of an existing key is not a structural modification of the map.
                if (function != entry.getValue())
                    functions.put(entry.getKey(), function);
            }
            catch (RuntimeException e) {
                throw new IllegalArgumentException("Could not prepare an evaluation context for " + entry, e);
            }
        }
        this.contextPrototypes = Map.copyOf(contextBuilder);
        this.functions = functions.entrySet().stream()
                                  .map(f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName())))
                                  .toList();
        this.publicFunctions = functions.values().stream()
                                        .filter(f -> ! f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX))
                                        .toList();
        this.referencedFunctions = Map.copyOf(referencedFunctions);
        this.closeActions = onnxModels.stream().<Runnable>map(model -> model::close).toList();
    }

    /**
     * Returns the given function with an argument required by its evaluation context added, unless already present.
     * Intermediate functions get an untyped argument unless it is among their arguments. Other functions get the
     * declared type of the argument (or the empty tensor type) unless they already have a type for it.
     */
    private static ExpressionFunction withContextArgument(ExpressionFunction function,
                                                          String argument,
                                                          Map<String, TensorType> declaredTypes) {
        if (function.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX))
            return function.arguments().contains(argument) ? function : function.withArgument(argument);
        if (function.getArgumentType(argument) != null)
            return function;
        return function.withArgument(argument, declaredTypes.getOrDefault(argument, TensorType.empty));
    }

    /** Optimizes the body of the given function in place using the given context, and returns the same function. */
    private ExpressionFunction optimize(ExpressionFunction function, ContextIndex context) {
        expressionOptimizer.optimize(function.getBody(), context);
        return function;
    }

    /** Returns the name of this model. */
    public String name() {
        return name;
    }

    /** Returns the functions of this model, excluding intermediate operation functions, as an unmodifiable list. */
    public List<ExpressionFunction> functions() {
        return publicFunctions;
    }

    /** Returns the context prototype of the named function, or throws IllegalArgumentException if there is none. */
    private LazyArrayContext requireContextPrototype(String name) {
        LazyArrayContext context = contextPrototypes.get(name);
        if (context == null)
            throw new IllegalArgumentException("No function named '" + name + "' in " + this +
                                               ". Available functions: " + availableFunctionNames());
        return context;
    }

    /** Returns the names of all functions of this model, including intermediate ones, comma separated. */
    private String availableFunctionNames() {
        return functions.stream().map(ExpressionFunction::getName).collect(Collectors.joining(", "));
    }

    /** Returns the function (public or intermediate) with exactly the given name, or null if there is none. */
    ExpressionFunction function(String name) {
        return functions.stream().filter(f -> f.getName().equals(name)).findFirst().orElse(null);
    }

    /** Returns the referenced functions given at construction, as an unmodifiable map. */
    Map<FunctionReference, ExpressionFunction> referencedFunctions() {
        return referencedFunctions; // already an unmodifiable copy
    }

    /**
     * Returns the referenced function with the given reference.
     *
     * @throws IllegalArgumentException if there is no such referenced function
     */
    ExpressionFunction requireReferencedFunction(FunctionReference reference) {
        ExpressionFunction function = referencedFunctions.get(reference);
        if (function == null)
            throw new IllegalArgumentException("No " + reference + " in " + this + ". References: " +
                                               referencedFunctions.keySet().stream()
                                                                  .map(FunctionReference::serialForm)
                                                                  .collect(Collectors.joining(", ")));
        return function;
    }

    /**
     * Returns a new evaluator of a function of this model, holding its own copy of the function's context.
     * <ul>
     *   <li>With no names, the model must contain exactly one function (counting intermediate ones).</li>
     *   <li>With one name, the function is chosen as the first rule that applies:
     *       a function with exactly that name;
     *       else the single function whose name starts with {@code name + "."};
     *       else the single function whose name ends with {@code "." + name};
     *       else, if the name starts with {@code serving_default}, the name with that prefix replaced by {@code default};
     *       else, if the name starts with {@code default.}, the name with that prefix removed.</li>
     *   <li>With two names, the name {@code names[0] + "." + names[1]} is looked up as above.</li>
     * </ul>
     *
     * @param names zero, one or two names identifying the function
     * @throws IllegalArgumentException if no function or several functions match, or more than two names are given
     */
    public FunctionEvaluator evaluatorOf(String ... names) {
        if (names.length == 0) {
            if (functions.isEmpty())
                throw new IllegalArgumentException("No functions are available in " + this);
            if (functions.size() > 1)
                throw undeterminedFunction("More than one function is available in " + this + ", but no name is given");
            return evaluatorOf(functions.get(0));
        }
        if (names.length == 1)
            return evaluatorOfName(names[0]);
        if (names.length == 2)
            return evaluatorOf(names[0] + "." + names[1]);
        throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " + Arrays.toString(names));
    }

    /** Resolves a single function name by the rules documented on {@link #evaluatorOf(String...)}. */
    private FunctionEvaluator evaluatorOfName(String name) {
        ExpressionFunction exact = function(name);
        if (exact != null)
            return evaluatorOf(exact);

        List<ExpressionFunction> startingWithName = functions.stream().filter(f -> f.getName().startsWith(name + ".")).toList();
        if (startingWithName.size() == 1)
            return evaluatorOf(startingWithName.get(0));
        if (startingWithName.size() > 1)
            throw undeterminedFunction("Multiple functions start by '" + name + "' in " + this);

        List<ExpressionFunction> endingWithName = functions.stream().filter(f -> f.getName().endsWith("." + name)).toList();
        if (endingWithName.size() == 1)
            return evaluatorOf(endingWithName.get(0));
        if (endingWithName.size() > 1)
            throw undeterminedFunction("Multiple functions called '" + name + "' in " + this);

        if (name.startsWith("serving_default"))
            return evaluatorOf("default" + name.substring("serving_default".length()));
        if (name.startsWith("default."))
            return evaluatorOf(name.substring("default.".length()));
        throw undeterminedFunction("No function '" + name + "' in " + this);
    }

    /** Returns a new evaluator of the given function, with a fresh copy of its context prototype. */
    private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
        return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
    }

    /** Returns an exception with the given message followed by the names of all available functions. */
    private IllegalArgumentException undeterminedFunction(String message) {
        return new IllegalArgumentException(message + ". Available functions: " + availableFunctionNames());
    }

    @Override
    public String toString() {
        return "model '" + name + "'";
    }

    /**
     * Closes the ONNX models of this model. Every model is closed even if closing an earlier one fails;
     * the first failure is then rethrown, with any later failures added to it as suppressed exceptions.
     */
    @Override
    public void close() {
        RuntimeException failure = null;
        for (Runnable closeAction : closeActions) {
            try {
                closeAction.run();
            }
            catch (RuntimeException e) {
                if (failure == null)
                    failure = e;
                else
                    failure.addSuppressed(e);
            }
        }
        if (failure != null)
            throw failure;
    }

    /**
     * Replaces references to the {@code onnx} or {@code onnxModel} feature naming a known ONNX model with the
     * expression that model gives for the referenced output. It also records the input types the model wants
     * in the declared types.
     */
    static class OnnxReplacer extends ExpressionTransformer<TransformContext> {

        private final List<OnnxModel> onnxModels;

        /** Declared argument types; updated with the input types wanted by the ONNX models replaced. */
        private final Map<String, TensorType> declaredTypes;

        public OnnxReplacer(List<OnnxModel> onnxModels, Map<String, TensorType> declaredTypes) {
            this.onnxModels = onnxModels;
            this.declaredTypes = declaredTypes;
        }

        /** Returns the ONNX model with the given name, or null if there is none (including when name is null). */
        private OnnxModel getModel(String name) {
            for (OnnxModel model : onnxModels) {
                if (model.name().equals(name))
                    return model;
            }
            return null;
        }

        /**
         * Transforms the given node and its descendants. If the node itself is replaced, the children of the
         * replacement are the ones transformed.
         *
         * @throws IllegalArgumentException if a replaced model wants an input type which conflicts with a declared type
         */
        public ExpressionNode transform(ExpressionNode node, TransformContext context) {
            ExpressionNode result = node;
            if (node instanceof ReferenceNode referenceNode && isOnnxReference(referenceNode.reference()))
                result = replaceOnnxReference(referenceNode);
            if (result instanceof CompositeNode composite)
                result = transformChildren(composite, context);
            if (result != node) {
                ExpressionNode transformed = result;
                logger.fine(() -> "transformed: " + node + " => " + transformed);
            }
            return result;
        }

        private static boolean isOnnxReference(Reference reference) {
            return reference.name().equals("onnx") || reference.name().equals("onnxModel");
        }

        /**
         * Returns the expression of the output referenced by the given ONNX reference node, after loading the model
         * and declaring its wanted input types. Returns the node itself if the model or output is unknown.
         */
        private ExpressionNode replaceOnnxReference(ReferenceNode node) {
            Reference ref = node.reference();
            logger.fine(() -> "consider replacing: " + ref);
            String modelName = ref.simpleArgument().orElse(null);
            OnnxModel model = getModel(modelName);
            if (model == null) {
                logger.fine(() -> "no onnx model named " + modelName);
                return node;
            }
            model.load();
            ExpressionNode replacement = model.getExpressionForOutput(ref.output());
            if (replacement == null) {
                logger.fine(() -> "no output named " + ref.output() + " from " + model);
                return node;
            }
            logger.fine(() -> "Replacing " + node + " => " + replacement);
            for (OnnxModel.InputSpec inputSpec : model.inputSpecs) {
                // Check before writing, so a conflict does not leave the conflicting type in declaredTypes
                TensorType existing = declaredTypes.get(inputSpec.source);
                if (existing != null && ! existing.equals(inputSpec.wantedType))
                    throw new IllegalArgumentException("Conflicting types needed for " + inputSpec.source + "; " +
                                                       existing + " != " + inputSpec.wantedType);
                declaredTypes.put(inputSpec.source, inputSpec.wantedType);
            }
            return replacement;
        }
    }
}

