/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.LazyArrayContext;
import ai.vespa.models.evaluation.OnnxModel;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Evaluates a single {@link ExpressionFunction} against a {@link LazyArrayContext}.
 *
 * <p>Typical use is to bind every declared argument with one of the {@code bind} methods and then call
 * {@link #evaluate()}. Once {@code evaluate()} has run past its argument checks, the evaluator is marked
 * as used, and further calls to {@code bind} or {@code setMissingValue} throw {@link IllegalStateException}.
 *
 * <p>Instances hold mutable state (the bound context and the "used" flag) without any synchronization.
 */
public class FunctionEvaluator {

    private final ExpressionFunction function;
    private final LazyArrayContext context;

    /** Set once {@link #evaluate()} has validated the arguments; blocks further mutation of the context. */
    private boolean evaluated = false;

    FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) {
        this.function = function;
        this.context = context;
    }

    /**
     * Binds a tensor value to a declared argument of this function.
     *
     * @param name  the argument name; must be one of the function's declared arguments
     * @param value the value; its type must be assignable to the argument's declared type
     * @return this, for chaining
     * @throws IllegalStateException    if this evaluator has already been used
     * @throws IllegalArgumentException if {@code name} is not a declared argument, or if {@code value}
     *                                  has a type that is not assignable to the argument's type
     */
    public FunctionEvaluator bind(String name, Tensor value) {
        requireNotEvaluated("Cannot bind a new value in a used evaluator");
        TensorType requiredType = function.getArgumentType(name);
        if (requiredType == null) {
            throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
                                               ". Expected arguments: " + describeArguments());
        }
        if ( ! value.type().isAssignableTo(requiredType)) {
            throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType +
                                               ", not " + value.type());
        }
        context.put(name, new TensorValue(value));
        return this;
    }

    /**
     * Binds a number to a declared argument, as a tensor of the empty (rank-0) type.
     *
     * @see #bind(String, Tensor)
     */
    public FunctionEvaluator bind(String name, double value) {
        return bind(name, scalar(value));
    }

    /**
     * Binds a string value under the given name.
     *
     * <p>Unlike the tensor overloads, this does not check {@code name} against the function's declared
     * arguments or their types.
     *
     * @return this, for chaining
     * @throws IllegalStateException if this evaluator has already been used
     */
    public FunctionEvaluator bind(String name, String value) {
        requireNotEvaluated("Cannot bind a new value in a used evaluator");
        context.put(name, new StringValue(value));
        return this;
    }

    /**
     * Sets the context's missing value. Presumably this is what unbound references resolve to; see
     * {@link LazyArrayContext#setMissingValue} to confirm.
     *
     * @return this, for chaining
     * @throws IllegalStateException if this evaluator has already been used
     */
    public FunctionEvaluator setMissingValue(Tensor value) {
        requireNotEvaluated("Cannot change the missing value in a used evaluator");
        context.setMissingValue(value);
        return this;
    }

    /**
     * Sets the context's missing value to a number, as a tensor of the empty (rank-0) type.
     *
     * @see #setMissingValue(Tensor)
     */
    public FunctionEvaluator setMissingValue(double value) {
        return setMissingValue(scalar(value));
    }

    /**
     * Evaluates the function and returns the result as a tensor.
     *
     * <p>Every declared argument is first checked, in sorted name order, so that the first failure
     * reported is deterministic. If all checks pass, the evaluator is marked as used. Next, any ONNX
     * models registered in the context are evaluated, and finally the function body is evaluated.
     *
     * @return the function body's value, as a tensor
     * @throws IllegalStateException if a declared argument is unbound, or is bound to a value of a
     *                               type that is not assignable to the declared type
     */
    public Tensor evaluate() {
        function.argumentTypes().entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .forEach(argument -> checkArgument(argument.getKey(), argument.getValue()));
        // Only mark as used after validation: a failed check leaves the evaluator open for further binds.
        evaluated = true;
        evaluateOnnxModels();
        return function.getBody().evaluate(context).asTensor();
    }

    /** Throws if {@code name} is unbound in the context, or is bound to a value not assignable to {@code type}. */
    private void checkArgument(String name, TensorType type) {
        if (context.isMissing(name)) {
            throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type);
        }
        if ( ! context.get(name).type().isAssignableTo(type)) {
            throw new IllegalStateException("Argument '" + name + "' must be bound to a value of type " + type);
        }
    }

    /**
     * Evaluates each ONNX model in the context whose feature value still equals the context's default
     * value, and stores the result under that feature name. Features holding any other value are left
     * untouched; presumably they were already computed.
     */
    private void evaluateOnnxModels() {
        // Use the field consistently, rather than the overridable context() accessor, so the model
        // map and the values always come from the same context.
        for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
            String onnxFeature = entry.getKey();
            if ( ! context.get(onnxFeature).equals(context.defaultValue())) {
                continue;
            }
            OnnxModel onnxModel = entry.getValue();
            Map<String, Tensor> inputs = new HashMap<>();
            for (String inputName : onnxModel.inputs().keySet()) {
                inputs.put(inputName, context.get(inputName).asTensor());
            }
            Tensor result = onnxModel.evaluate(inputs, onnxOutputName(onnxFeature));
            context.put(onnxFeature, new TensorValue(result));
        }
    }

    /**
     * Returns the model output to request for an ONNX feature key. This is the text after the first
     * {@code ")."} when that marker is not at the start of the key and is followed by at least one
     * character. Otherwise it is this function's name.
     */
    private String onnxOutputName(String onnxFeature) {
        int idx = onnxFeature.indexOf(").");
        if (idx > 0 && idx + 2 < onnxFeature.length()) {
            return onnxFeature.substring(idx + 2);
        }
        return function.getName();
    }

    /** Throws an {@link IllegalStateException} with the given message if this evaluator has been used. */
    private void requireNotEvaluated(String message) {
        if (evaluated) {
            throw new IllegalStateException(message);
        }
    }

    /** Lists the declared arguments as {@code "name: type"} pairs, sorted by name and comma-separated. */
    private String describeArguments() {
        return function.argumentTypes().entrySet().stream()
                       .sorted(Map.Entry.comparingByKey())
                       .map(argument -> argument.getKey() + ": " + argument.getValue())
                       .collect(Collectors.joining(", "));
    }

    /** Returns a tensor of the empty (rank-0) type holding the single cell {@code value}. */
    private static Tensor scalar(double value) {
        return Tensor.Builder.of(TensorType.empty).cell(value).build();
    }

    /** Returns the function evaluated by this. */
    public ExpressionFunction function() {
        return function;
    }

    /** Returns the context that bound values are written to and that the function is evaluated in. */
    public LazyArrayContext context() {
        return context;
    }

}

