/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.schema;

import com.google.common.collect.ImmutableMap;
import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.expressiontransforms.OnnxModelTransformer;
import com.yahoo.schema.expressiontransforms.TokenTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
 * A type context for ranking expressions which resolves the tensor types of references from a map of
 * known feature types, the functions available in this context, and a few built-in tensor-producing features.
 *
 * <p>Contexts used to type function bodies are created with {@link #withBindings}. Such a child context shares
 * the resolution call stack, the set of undeclared query features and the cache of globally resolved types with
 * the context it was created from, but gets its own copy of the feature types and its own resolved-type cache.
 */
public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> {

    /** The context this was created from by {@link #withBindings}, or empty for a root context. */
    private final Optional<MapEvaluationTypeContext> parent;

    /** Known feature types. Each context owns a copy, so {@link #setType} only affects this context. */
    private final Map<Reference, TensorType> featureTypes = new HashMap<>();

    /** Types resolved by this context instance. */
    private final Map<Reference, TensorType> resolvedTypes = new HashMap<>();

    /** Resolved types of zero-argument function references, shared by all contexts derived from the same root. */
    private final Map<Reference, TensorType> globallyResolvedTypes;

    /** References currently being resolved (shared across contexts); used to detect invocation loops. */
    private final Deque<Reference> currentResolutionCallStack;

    /** Query features that were referenced without a declared type (shared across contexts). */
    private final SortedSet<Reference> queryFeaturesNotDeclared;

    /** Whether this context has resolved some reference to a type of rank greater than zero. */
    private boolean tensorsAreUsed;

    MapEvaluationTypeContext(ImmutableMap<String, ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) {
        super(functions);
        this.parent = Optional.empty();
        this.featureTypes.putAll(featureTypes);
        this.currentResolutionCallStack = new ArrayDeque<>();
        this.queryFeaturesNotDeclared = new TreeSet<>();
        this.tensorsAreUsed = false;
        this.globallyResolvedTypes = new HashMap<>();
    }

    private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
                                     Map<String, String> bindings,
                                     Optional<MapEvaluationTypeContext> parent,
                                     Map<Reference, TensorType> featureTypes,
                                     Deque<Reference> currentResolutionCallStack,
                                     SortedSet<Reference> queryFeaturesNotDeclared,
                                     boolean tensorsAreUsed,
                                     Map<Reference, TensorType> globallyResolvedTypes) {
        super(functions, bindings);
        this.parent = parent;
        this.featureTypes.putAll(featureTypes);
        this.currentResolutionCallStack = currentResolutionCallStack;
        this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
        this.tensorsAreUsed = tensorsAreUsed;
        this.globallyResolvedTypes = globallyResolvedTypes;
    }

    /**
     * Declares the type of a feature in this context.
     * The reference is also removed from the set of undeclared query features.
     */
    public void setType(Reference reference, TensorType type) {
        featureTypes.put(reference, type);
        queryFeaturesNotDeclared.remove(reference);
    }

    /** Returns an unmodifiable view of the feature types known to this context. */
    public Map<Reference, TensorType> featureTypes() {
        return Collections.unmodifiableMap(featureTypes);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always, as references cannot be parsed from string form here
     */
    public TensorType getType(String reference) {
        throw new UnsupportedOperationException("Not able to parse general references from string form");
    }

    /** Clears the types cached by this context (the globally shared cache is left untouched). */
    public void forgetResolvedTypes() {
        resolvedTypes.clear();
    }

    /** Returns whether the reference invokes a zero-argument function, whose type does not depend on bindings. */
    private boolean referenceCanBeResolvedGlobally(Reference reference) {
        Optional<ExpressionFunction> function = functionInvocation(reference);
        return function.isPresent() && function.get().arguments().isEmpty();
    }

    /**
     * Returns the type of the given reference, resolving and caching it if necessary.
     * Resolving a type of rank greater than zero marks tensors as used in this context.
     *
     * @return the resolved type, or the result of {@link #defaultTypeOf} (possibly null) if it cannot be resolved
     * @throws IllegalArgumentException on invocation loops, invalid feature arguments, or when the reference
     *                                  cannot be resolved and is not a simple feature
     */
    public TensorType getType(Reference reference) {
        boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference);

        TensorType resolvedType = resolvedTypes.get(reference);
        if (resolvedType == null && canBeResolvedGlobally) {
            resolvedType = globallyResolvedTypes.get(reference);
        }
        if (resolvedType != null) {
            return resolvedType;
        }

        resolvedType = resolveType(reference);
        if (resolvedType == null) {
            // The default is deliberately not cached, presumably so that a type registered later
            // through setType is picked up on the next lookup.
            return defaultTypeOf(reference);
        }
        resolvedTypes.put(reference, resolvedType);
        if (resolvedType.rank() > 0) {
            tensorsAreUsed = true;
        }
        if (canBeResolvedGlobally) {
            globallyResolvedTypes.put(reference, resolvedType);
        }
        return resolvedType;
    }

    /**
     * Returns the parent context.
     *
     * @throws IllegalArgumentException if this is a root context, naming the argument and what it is bound to
     */
    MapEvaluationTypeContext getParent(String forArgument, String boundTo) {
        return parent.orElseThrow(() -> new IllegalArgumentException(
                "argument " + forArgument + " is bound to " + boundTo + " but there is no parent context"));
    }

    /**
     * Follows the chain of argument bindings from this context upwards through the parents.
     *
     * @return the name the given name is ultimately bound to, or the name itself if it is not bound here
     */
    public String resolveBinding(String name) {
        String bound = getBinding(name);
        if (bound == null) {
            return name;
        }
        return getParent(name, bound).resolveBinding(bound);
    }

    /**
     * Resolves the type of a reference without consulting the caches.
     *
     * @return the resolved type, or null if a simple feature has no known type
     * @throws IllegalArgumentException if the reference is already being resolved (an invocation loop)
     */
    private TensorType resolveType(Reference reference) {
        if (currentResolutionCallStack.contains(reference)) {
            throw new IllegalArgumentException("Invocation loop: " +
                                               currentResolutionCallStack.stream()
                                                                         .map(Reference::toString)
                                                                         .collect(Collectors.joining(" -> ")) +
                                               " -> " + reference);
        }

        // An identifier bound to a function argument: type the bound expression in the parent context
        Optional<String> binding = boundIdentifier(reference);
        if (binding.isPresent()) {
            return typeOfBoundExpression(reference, binding.get());
        }

        // Push before entering try, so the finally block only ever pops what was pushed here
        currentResolutionCallStack.addLast(reference);
        try {
            return resolveUnboundType(reference);
        }
        finally {
            currentResolutionCallStack.removeLast();
        }
    }

    /** Parses the expression bound to the given reference and types it in the parent context. */
    private TensorType typeOfBoundExpression(Reference reference, String boundExpression) {
        try {
            RankingExpression expression = new RankingExpression(boundExpression);
            return expression.type(getParent(reference.name(), boundExpression));
        }
        catch (ParseException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /** Resolves a reference which is not itself a bound identifier, trying each kind of feature in turn. */
    private TensorType resolveUnboundType(Reference reference) {
        // A simple feature: its argument may be a local identifier bound to the actual value
        if (FeatureNames.isSimpleFeature(reference)) {
            String argumentBinding = resolveBinding(reference.simpleArgument().get());
            return featureTypes.get(Reference.simple(reference.name(), argumentBinding));
        }

        // A function invocation: type the body with the invocation arguments bound
        Optional<ExpressionFunction> function = functionInvocation(reference);
        if (function.isPresent()) {
            RankingExpression body = function.get().getBody();
            MapEvaluationTypeContext child = withBindings(bind(function.get().arguments(), reference.arguments()));
            return body.type(child);
        }

        Optional<TensorType> onnxType = onnxFeatureType(reference);
        if (onnxType.isPresent()) {
            return onnxType.get();
        }

        Optional<TensorType> tokensType = transformerTokensFeatureType(reference);
        if (tokensType.isPresent()) {
            return tokensType.get();
        }

        Optional<TensorType> tensorType = tensorFeatureType(reference);
        if (tensorType.isPresent()) {
            return tensorType.get();
        }

        if (reference.isIdentifier()) {
            // A directly injected identifier
            if (featureTypes.containsKey(reference)) {
                return featureTypes.get(reference);
            }
            // The name of a constant feature
            Reference asConst = FeatureNames.asConstantFeature(reference.name());
            if (featureTypes.containsKey(asConst)) {
                return featureTypes.get(asConst);
            }
        }

        // Unknown reference: fall back to the empty type
        return TensorType.empty;
    }

    /**
     * Returns the type to use for a simple feature whose type could not be resolved.
     * An undeclared query feature is recorded in {@link #queryFeaturesNotDeclared()}, and its type is empty.
     *
     * @return the empty type for query features, null for any other simple feature
     * @throws IllegalArgumentException if the reference is not a simple feature
     */
    public TensorType defaultTypeOf(Reference reference) {
        if (!FeatureNames.isSimpleFeature(reference)) {
            throw new IllegalArgumentException("This can only be called for simple references, not " + reference);
        }
        if (reference.name().equals("query")) {
            queryFeaturesNotDeclared.add(reference);
            return TensorType.empty;
        }
        return null;
    }

    /** Returns the binding of the reference's name if it is a plain name with no arguments and no output. */
    private Optional<String> boundIdentifier(Reference reference) {
        if (!reference.arguments().isEmpty()) {
            return Optional.empty();
        }
        if (reference.output() != null) {
            return Optional.empty();
        }
        return Optional.ofNullable(getBinding(reference.name()));
    }

    /** Returns the function the reference invokes, if one with this name and argument count exists and no output is selected. */
    private Optional<ExpressionFunction> functionInvocation(Reference reference) {
        if (reference.output() != null) {
            return Optional.empty();
        }
        ExpressionFunction function = getFunctions().get(reference.name());
        if (function == null) {
            return Optional.empty();
        }
        if (function.arguments().size() != reference.arguments().size()) {
            return Optional.empty();
        }
        return Optional.of(function);
    }

    /**
     * Returns the type of an {@code onnx}/{@code onnxModel} feature, or empty if this is not such a feature.
     * A reference that is not known as given is looked up in the form built by {@link OnnxModelTransformer}.
     *
     * @throws IllegalArgumentException if an unknown reference has no arguments, or no matching model config exists
     */
    private Optional<TensorType> onnxFeatureType(Reference reference) {
        if (!reference.name().equals("onnxModel") && !reference.name().equals("onnx")) {
            return Optional.empty();
        }
        if (featureTypes.containsKey(reference)) {
            return Optional.of(featureTypes.get(reference));
        }
        // Validate like the other feature handlers instead of failing with an IndexOutOfBoundsException
        if (reference.arguments().isEmpty()) {
            throw new IllegalArgumentException(reference.name() + " must have at least one argument");
        }
        String configOrFileName = reference.arguments().expressions().get(0).toString();

        String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
        String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
        Reference standardized = new Reference("onnx", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
        if (!featureTypes.containsKey(standardized)) {
            throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
        }
        return Optional.of(featureTypes.get(standardized));
    }

    /**
     * Returns the type of a token feature ({@code tokenTypeIds}, {@code tokenInputIds}, {@code tokenAttentionMask}),
     * or empty if this is not such a feature.
     *
     * @throws IllegalArgumentException if the feature has fewer than two arguments
     */
    private Optional<TensorType> transformerTokensFeatureType(Reference reference) {
        String name = reference.name();
        if (!(name.equals("tokenTypeIds") || name.equals("tokenInputIds") || name.equals("tokenAttentionMask"))) {
            return Optional.empty();
        }
        if (reference.arguments().size() <= 1) {
            throw new IllegalArgumentException(name + " must have at least 2 arguments");
        }
        ExpressionNode size = reference.arguments().expressions().get(0);
        return Optional.of(TokenTransformer.createTensorType(name, size));
    }

    /**
     * Returns the type of a {@code tensorFromLabels}, {@code tensorFromWeightedSet} or {@code closest} feature,
     * or empty if this is not such a feature.
     *
     * @throws IllegalArgumentException if the feature does not have one or two valid arguments
     */
    private Optional<TensorType> tensorFeatureType(Reference reference) {
        String name = reference.name();
        if (!(name.equals("tensorFromLabels") || name.equals("tensorFromWeightedSet") || name.equals("closest"))) {
            return Optional.empty();
        }
        int argumentCount = reference.arguments().size();
        if (argumentCount != 1 && argumentCount != 2) {
            throw new IllegalArgumentException(name + " must have one or two arguments");
        }
        ExpressionNode arg0 = reference.arguments().expressions().get(0);
        if (name.equals("closest")) {
            return Optional.of(closestType(reference, arg0));
        }

        if (!(arg0 instanceof ReferenceNode) || !FeatureNames.isSimpleFeature(((ReferenceNode) arg0).reference())) {
            throw new IllegalArgumentException("The first argument of " + name + " must be a simple feature, not " + arg0);
        }
        String dimension;
        if (argumentCount > 1) {
            ExpressionNode arg1 = reference.arguments().expressions().get(1);
            boolean isIdentifierReference = arg1 instanceof ReferenceNode && ((ReferenceNode) arg1).reference().isIdentifier();
            if (!isIdentifierReference && !(arg1 instanceof NameNode)) {
                throw new IllegalArgumentException("The second argument of " + name + " must be a dimension name, not " + arg1);
            }
            dimension = arg1.toString();
        } else {
            // No explicit dimension: use the argument of the simple feature
            dimension = ((ReferenceNode) arg0).reference().arguments().expressions().get(0).toString();
        }
        return Optional.of(new TensorType.Builder().mapped(dimension).build());
    }

    /**
     * Returns the type of a {@code closest} feature: the mapped subtype of the tensor attribute named by its
     * first argument.
     *
     * @throws IllegalArgumentException if the first argument does not name a known attribute of rank greater than
     *                                  zero, or if that attribute's type has no mapped dimensions
     */
    private TensorType closestType(Reference reference, ExpressionNode arg0) {
        if (arg0 instanceof ReferenceNode) {
            Reference argRef = ((ReferenceNode) arg0).reference();
            if (argRef.isIdentifier()) {
                Reference attrFeature = FeatureNames.asAttributeFeature(argRef.name());
                TensorType attrType = featureTypes.get(attrFeature);
                if (attrType != null && attrType.rank() > 0) {
                    TensorType mapped = attrType.mappedSubtype();
                    if (mapped.rank() > 0) {
                        return mapped;
                    }
                    throw new IllegalArgumentException("Unexpected tensor type " + attrType + " for " + attrFeature + " used by " + reference);
                }
            }
        }
        throw new IllegalArgumentException("The first argument of " + reference.name() + " must be the name of a tensor attribute, not " + arg0);
    }

    /**
     * Binds each formal argument name to the string form of the invocation argument at the same position.
     * The caller must ensure that there are at least as many invocation arguments as formal arguments.
     */
    private static Map<String, String> bind(List<String> formalArguments, Arguments invocationArguments) {
        Map<String, String> bindings = new HashMap<>(formalArguments.size());
        for (int i = 0; i < formalArguments.size(); i++) {
            bindings.put(formalArguments.get(i), invocationArguments.expressions().get(i).toString());
        }
        return bindings;
    }

    /** Returns an unmodifiable view of the query features that were referenced without a declared type. */
    public SortedSet<Reference> queryFeaturesNotDeclared() {
        return Collections.unmodifiableSortedSet(queryFeaturesNotDeclared);
    }

    /** Returns whether this context has resolved some reference to a type of rank greater than zero. */
    public boolean tensorsAreUsed() {
        return tensorsAreUsed;
    }

    /**
     * Returns a child context with the given argument bindings and this context as its parent.
     * The child shares the call stack, the undeclared query features and the global type cache with this context,
     * and gets copies of the feature types.
     * NOTE(review): tensorsAreUsed is copied by value, so if the child sets it, this context does not see the change.
     * Confirm that this is intended.
     */
    public MapEvaluationTypeContext withBindings(Map<String, String> bindings) {
        return new MapEvaluationTypeContext(getFunctions(),
                                            bindings,
                                            Optional.of(this),
                                            featureTypes,
                                            currentResolutionCallStack,
                                            queryFeaturesNotDeclared,
                                            tensorsAreUsed,
                                            globallyResolvedTypes);
    }

}

