/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.Constant;
import ai.vespa.models.evaluation.FunctionReference;
import ai.vespa.models.evaluation.LazyValue;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.OnnxModel;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * An array-backed evaluation {@link Context} for one ranking expression function.
 *
 * <p>On construction the function body is scanned to find every name a value can be bound to. The
 * scan also covers, transitively, the body of every function it references. Each such name gets a
 * fixed array index, so values can be read and written either by name or by index
 * ({@link ContextIndex}). Names without a bound value read back as a configurable "missing" value.
 * That value is NaN unless changed with {@link #setMissingValue(Tensor)}.
 */
public final class LazyArrayContext
extends Context
implements ContextIndex {

    /** The function this context was created for; kept so {@link #copy()} can pass it on. */
    private final ExpressionFunction function;

    /** The name-to-slot mapping together with the slot values. */
    private final IndexedBindings indexedBindings;

    /** Creates a copy of the given bindings whose lazy function values are re-bound to this context. */
    private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) {
        this.function = function;
        this.indexedBindings = indexedBindings.copy(this);
    }

    /**
     * Creates a context for evaluating {@code function}.
     *
     * @param function the function to create a context for
     * @param referencedFunctions functions that may be referenced as {@code rankingExpression(name)};
     *                            every function reachable from {@code function} must be present
     * @param constants constants that may be referenced as {@code constant(name)}
     * @param onnxModels ONNX models that may be referenced as {@code onnx(name)} or {@code onnxModel(name)}
     * @param model passed on to the {@link LazyValue}s created for referenced functions
     */
    LazyArrayContext(ExpressionFunction function,
                     Map<FunctionReference, ExpressionFunction> referencedFunctions,
                     List<Constant> constants,
                     List<OnnxModel> onnxModels,
                     Model model) {
        this.function = function;
        this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model);
    }

    /** Sets the value returned for bindable names that currently have no value bound. */
    public void setMissingValue(Tensor value) {
        indexedBindings.setMissingValue(value);
    }

    /**
     * Binds a value to a name. The value is frozen before it is stored.
     *
     * @throws IllegalArgumentException if {@code name} cannot be bound in this context
     */
    public void put(String name, Value value) {
        put(requireIndexOf(name), value);
    }

    /** Binds a double to the slot at {@code index}, stored as a frozen {@link DoubleValue}. */
    public final void put(int index, double value) {
        Value frozen = DoubleValue.frozen(value);
        put(index, frozen);
    }

    /** Binds a value to the slot at {@code index}. The value is frozen before it is stored. */
    public void put(int index, Value value) {
        indexedBindings.set(index, value.freeze());
    }

    /**
     * Returns the type of the value currently bound to the given reference, or the missing value's
     * type if nothing is bound.
     *
     * @throws IllegalArgumentException if the reference cannot be bound in this context
     */
    public TensorType getType(Reference reference) {
        return get(requireIndexOf(reference.toString())).type();
    }

    /**
     * Returns the value bound to {@code name}, or the missing value if nothing is bound.
     *
     * @throws IllegalArgumentException if {@code name} cannot be bound in this context
     */
    public Value get(String name) {
        return get(requireIndexOf(name));
    }

    /** Returns the value in the slot at {@code index}, or the missing value if nothing is bound. */
    public Value get(int index) {
        return indexedBindings.get(index);
    }

    /** Returns the value in the slot at {@code index} as a double. */
    public double getDouble(int index) {
        return get(index).asDouble();
    }

    /**
     * Returns the slot index of {@code name}.
     *
     * @throws IllegalArgumentException if {@code name} cannot be bound in this context
     */
    public int getIndex(String name) {
        return requireIndexOf(name);
    }

    /** Returns the number of bindable names, which is also the number of slots. */
    public int size() {
        return indexedBindings.names().size();
    }

    /** Returns all bindable names, in slot-index order. */
    public Set<String> names() {
        return indexedBindings.names();
    }

    /**
     * Returns the subset of bindable names that are inputs the caller is expected to supply.
     * These are plain references, ONNX feature names and ONNX model input names. Function
     * references and constants are excluded.
     */
    public Set<String> arguments() {
        return indexedBindings.arguments();
    }

    /** Returns the ONNX models used by the function, keyed by the ONNX feature string that references them. */
    public Map<String, OnnxModel> onnxModels() {
        return indexedBindings.onnxModels();
    }

    /** Returns the slot index of {@code name}, failing if it is not bindable here. */
    private int requireIndexOf(String name) {
        Integer index = indexedBindings.indexOf(name);
        if (index == null) {
            throw new IllegalArgumentException("Value '" + name + "' can not be bound in " + this);
        }
        return index;
    }

    /**
     * Returns true if {@code name} is not a bindable name in this context. Note that this does
     * not tell whether a value has been bound to a bindable name.
     */
    boolean isMissing(String name) {
        return indexedBindings.indexOf(name) == null;
    }

    /** Returns the value that unbound slots read back as. */
    public Value defaultValue() {
        return indexedBindings.missingValue;
    }

    /**
     * Returns a copy of this context. Bound values and the missing value are carried over, and
     * lazy function values are re-bound to the copy.
     */
    LazyArrayContext copy() {
        return new LazyArrayContext(function, indexedBindings);
    }

    /** The bindable names of a function, their slot indices, and the slot values. */
    private static class IndexedBindings {

        /** Sentinel stored in unbound slots; compared by identity, never returned to callers. */
        private static final Value MISSING = new DoubleValue(Double.NaN).freeze();

        /** Bindable name to slot index, in discovery order. */
        private final ImmutableMap<String, Integer> nameToIndex;

        /** Bindable names the caller is expected to supply (see {@link LazyArrayContext#arguments()}). */
        private final ImmutableSet<String> arguments;

        /** Slot values, indexed by {@link #nameToIndex}; {@link #MISSING} marks an unbound slot. */
        private final Value[] values;

        /** ONNX models in use, keyed by the ONNX feature string referencing them. */
        private final ImmutableMap<String, OnnxModel> onnxModels;

        /** What unbound slots read back as. */
        private Value missingValue = new DoubleValue(Double.NaN).freeze();

        /** Creates bindings from already computed parts; used when copying. */
        private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
                                Value[] values,
                                ImmutableSet<String> arguments,
                                ImmutableMap<String, OnnxModel> onnxModels) {
            this.nameToIndex = nameToIndex;
            this.values = values;
            this.arguments = arguments;
            this.onnxModels = onnxModels;
        }

        /**
         * Scans {@code function} and everything it references to create the bindings. Slots of
         * referenced constants are filled with their tensor values. Slots of referenced functions
         * are filled with {@link LazyValue}s owned by {@code owner}. All other slots start unbound.
         */
        IndexedBindings(ExpressionFunction function,
                        Map<FunctionReference, ExpressionFunction> referencedFunctions,
                        List<Constant> constants,
                        List<OnnxModel> onnxModels,
                        LazyArrayContext owner,
                        Model model) {
            // LinkedHashSet keeps discovery order, which becomes the slot order.
            Set<String> bindTargets = new LinkedHashSet<>();
            Set<String> arguments = new LinkedHashSet<>();
            Map<String, OnnxModel> onnxModelsInUse = new HashMap<>();
            extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse);
            this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse);
            this.arguments = ImmutableSet.copyOf(arguments);

            this.values = new Value[bindTargets.size()];
            Arrays.fill(this.values, MISSING);

            ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
            int index = 0;
            for (String name : bindTargets) {
                nameToIndexBuilder.put(name, index++);
            }
            this.nameToIndex = nameToIndexBuilder.build();

            for (Constant constant : constants) {
                Integer constantIndex = this.nameToIndex.get("constant(" + constant.name() + ")");
                if (constantIndex != null) {
                    this.values[constantIndex] = new TensorValue(constant.value());
                }
            }
            for (FunctionReference reference : referencedFunctions.keySet()) {
                Integer functionIndex = this.nameToIndex.get(reference.serialForm());
                if (functionIndex != null) {
                    // Presumably evaluates the referenced function on demand within owner - verify in LazyValue.
                    this.values[functionIndex] = new LazyValue(reference, owner, model);
                }
            }
        }

        private void setMissingValue(Tensor value) {
            this.missingValue = new TensorValue(value).freeze();
        }

        /**
         * Recursively collects the bindable names of {@code node} into {@code bindTargets}. Names
         * the caller must supply are also added to {@code arguments}. ONNX models are collected
         * into {@code onnxModelsInUse}.
         *
         * <p>Throws a NullPointerException if a referenced function is missing from {@code functions}.
         */
        private static void extractBindTargets(ExpressionNode node,
                                               Map<FunctionReference, ExpressionFunction> functions,
                                               Set<String> bindTargets,
                                               Set<String> arguments,
                                               List<OnnxModel> onnxModels,
                                               Map<String, OnnxModel> onnxModelsInUse) {
            if (isFunctionReference(node)) {
                FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
                bindTargets.add(reference.serialForm());
                // NOTE(review): a function referenced from several places is re-walked each time;
                // consider a visited set if function graphs get large.
                ExpressionNode functionNode = functions.get(reference).getBody().getRoot();
                extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
            } else if (isOnnx(node)) {
                extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse);
            } else if (isConstant(node)) {
                bindTargets.add(node.toString());
            } else if (node instanceof ReferenceNode) {
                bindTargets.add(node.toString());
                arguments.add(node.toString());
            } else if (node instanceof CompositeNode) {
                for (ExpressionNode child : ((CompositeNode) node).children()) {
                    extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
                }
            }
        }

        /**
         * Handles an {@code onnx(...)}/{@code onnxModel(...)} node. For each model in
         * {@code onnxModels} whose name matches the node's first argument, it adds the feature
         * string and the model's input names as bind targets and arguments. It also loads the
         * model and records it in {@code onnxModelsInUse}.
         */
        private static void extractOnnxTargets(ExpressionNode node,
                                               Set<String> bindTargets,
                                               Set<String> arguments,
                                               List<OnnxModel> onnxModels,
                                               Map<String, OnnxModel> onnxModelsInUse) {
            Optional<String> modelName = getArgument(node);
            if ( ! modelName.isPresent()) {
                return;
            }
            for (OnnxModel onnxModel : onnxModels) {
                if ( ! onnxModel.name().equals(modelName.get())) {
                    continue;
                }
                String onnxFeature = node.toString();
                bindTargets.add(onnxFeature);
                arguments.add(onnxFeature);
                onnxModel.load();
                for (String input : onnxModel.inputs().keySet()) {
                    bindTargets.add(input);
                    arguments.add(input);
                }
                onnxModelsInUse.put(onnxFeature, onnxModel);
            }
        }

        /**
         * Returns the first argument of a reference node as a name. A constant argument is
         * returned with surrounding quotes removed, and a reference argument as its name. Returns
         * empty for other nodes and argument kinds.
         */
        private static Optional<String> getArgument(ExpressionNode node) {
            if ( ! (node instanceof ReferenceNode)) {
                return Optional.empty();
            }
            ReferenceNode reference = (ReferenceNode) node;
            if (reference.getArguments().size() == 0) {
                return Optional.empty();
            }
            ExpressionNode firstArgument = reference.getArguments().expressions().get(0);
            if (firstArgument instanceof ConstantNode) {
                return Optional.of(stripQuotes(((ConstantNode) firstArgument).sourceString()));
            }
            if (firstArgument instanceof ReferenceNode) {
                return Optional.of(((ReferenceNode) firstArgument).getName());
            }
            return Optional.empty();
        }

        /**
         * Removes one pair of matching surrounding double or single quotes from {@code s}.
         * Strings shorter than two characters are returned unchanged (previously these threw
         * {@link StringIndexOutOfBoundsException}).
         */
        public static String stripQuotes(String s) {
            if (s.length() < 2) {
                return s;
            }
            char first = s.charAt(0);
            char last = s.charAt(s.length() - 1);
            if ((first == '"' || first == '\'') && first == last) {
                return s.substring(1, s.length() - 1);
            }
            return s;
        }

        /** True for a {@code rankingExpression(name)} reference with exactly one argument. */
        private static boolean isFunctionReference(ExpressionNode node) {
            if ( ! (node instanceof ReferenceNode)) {
                return false;
            }
            ReferenceNode reference = (ReferenceNode) node;
            return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
        }

        /** True for a reference named {@code onnx} or {@code onnxModel}. */
        private static boolean isOnnx(ExpressionNode node) {
            if ( ! (node instanceof ReferenceNode)) {
                return false;
            }
            ReferenceNode reference = (ReferenceNode) node;
            return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
        }

        /** True for a {@code constant(name)} reference with exactly one argument. */
        private static boolean isConstant(ExpressionNode node) {
            if ( ! (node instanceof ReferenceNode)) {
                return false;
            }
            ReferenceNode reference = (ReferenceNode) node;
            return reference.getName().equals("constant") && reference.getArguments().size() == 1;
        }

        /** Returns the slot value, or the current missing value if the slot is unbound. */
        Value get(int index) {
            Value value = values[index];
            // Identity check: a caller-bound NaN must not be mistaken for an unbound slot.
            return value == MISSING ? missingValue : value;
        }

        void set(int index, Value value) {
            values[index] = value;
        }

        /** Bindable names in slot-index order (ImmutableMap preserves insertion order). */
        Set<String> names() {
            return nameToIndex.keySet();
        }

        Set<String> arguments() {
            return arguments;
        }

        /** Returns the slot index of {@code name}, or null if it is not bindable. */
        Integer indexOf(String name) {
            return nameToIndex.get(name);
        }

        Map<String, OnnxModel> onnxModels() {
            return onnxModels;
        }

        /**
         * Returns a copy whose slot array is independent of this one. LazyValues are re-bound to
         * {@code context} and other values are shared, since they are frozen when bound via put.
         * The missing value is carried over.
         */
        IndexedBindings copy(Context context) {
            Value[] valueCopy = new Value[values.length];
            for (int i = 0; i < values.length; i++) {
                valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i];
            }
            IndexedBindings duplicate = new IndexedBindings(nameToIndex, valueCopy, arguments, onnxModels);
            // Previously dropped, so a copied context silently reverted to NaN for unbound slots.
            duplicate.missingValue = this.missingValue;
            return duplicate;
        }
    }
}

