/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.schema.expressiontransforms;

import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.OnnxModel;
import com.yahoo.schema.RankProfile;
import com.yahoo.schema.expressiontransforms.InputRecorderContext;
import com.yahoo.schema.expressiontransforms.OnnxModelTransformer;
import com.yahoo.schema.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.TensorFunction;
import java.io.Reader;
import java.io.StringReader;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

public class InputRecorder
extends ExpressionTransformer<InputRecorderContext> {
    private static final Logger log = Logger.getLogger(InputRecorder.class.getName());
    private final Set<String> neededInputs;
    private final Set<String> handled = new HashSet<String>();
    private final Set<String> availableNormalizers = new HashSet<String>();
    private final Set<String> usedNormalizers = new HashSet<String>();

    public InputRecorder(Set<String> target) {
        this.neededInputs = target;
    }

    public void process(RankingExpression expression, RankProfileTransformContext context) {
        this.process(expression.getRoot(), context);
    }

    public void process(ExpressionNode node, RankProfileTransformContext context) {
        this.transform(node, new InputRecorderContext(context));
    }

    public void alreadyMatchFeatures(Collection<String> matchFeatures) {
        for (String mf : matchFeatures) {
            this.handled.add(mf);
        }
    }

    public void addKnownNormalizers(Collection<String> names) {
        for (String name : names) {
            this.availableNormalizers.add(name);
        }
    }

    public Set<String> normalizersUsed() {
        return this.usedNormalizers;
    }

    public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) {
        if (node instanceof ReferenceNode) {
            ReferenceNode r = (ReferenceNode)node;
            this.handle(r, context);
            return node;
        }
        if (node instanceof TensorFunctionNode) {
            TensorFunctionNode t = (TensorFunctionNode)node;
            TensorFunction f = t.function();
            if (f instanceof Generate) {
                InputRecorderContext childContext = new InputRecorderContext(context);
                TensorType tt = f.type(context.types());
                for (TensorType.Dimension dim : tt.dimensions()) {
                    childContext.localVariables().add(dim.name());
                }
                return this.transformChildren((CompositeNode)t, childContext);
            }
            node = t.withTransformedExpressions(expr -> this.transform((ExpressionNode)expr, context));
        }
        if (node instanceof CompositeNode) {
            CompositeNode c = (CompositeNode)node;
            return this.transformChildren(c, context);
        }
        if (node instanceof ConstantNode) {
            return node;
        }
        throw new IllegalArgumentException("Cannot handle node type: " + node + " [" + node.getClass() + "]");
    }

    private void handle(ReferenceNode feature, InputRecorderContext context) {
        boolean simpleFunctionOrIdentifier;
        Reference ref = feature.reference();
        String name = ref.name();
        Arguments args = ref.arguments();
        boolean bl = simpleFunctionOrIdentifier = args.size() == 0 && ref.output() == null;
        if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) {
            return;
        }
        if (simpleFunctionOrIdentifier && this.availableNormalizers.contains(name)) {
            this.usedNormalizers.add(name);
            return;
        }
        if (ref.isSimpleRankingExpressionWrapper()) {
            name = (String)ref.simpleArgument().get();
            simpleFunctionOrIdentifier = true;
        }
        if (simpleFunctionOrIdentifier) {
            if (this.handled.contains(name)) {
                return;
            }
            RankProfile.RankingExpressionFunction f = context.rankProfile().getFunctions().get(name);
            if (f != null && f.function().arguments().size() == 0) {
                this.transform(f.function().getBody().getRoot(), context);
                this.handled.add(name);
                return;
            }
            this.neededInputs.add(feature.toString());
            return;
        }
        if (FeatureNames.isSimpleFeature(ref)) {
            if (FeatureNames.isAttributeFeature(ref)) {
                this.neededInputs.add(feature.toString());
                return;
            }
            if (FeatureNames.isQueryFeature(ref)) {
                return;
            }
            if (FeatureNames.isConstantFeature(ref)) {
                Map<Reference, RankProfile.Constant> allConstants = context.rankProfile().constants();
                if (allConstants.containsKey(ref)) {
                    return;
                }
                throw new IllegalArgumentException("unknown constant: " + feature);
            }
        }
        if ("onnx".equals(name)) {
            ExpressionNode tmp;
            if (args.size() < 1) {
                throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature);
            }
            ExpressionNode arg = (ExpressionNode)args.expressions().get(0);
            Map<String, OnnxModel> models = context.rankProfile().onnxModels();
            OnnxModel model = models.get(arg.toString());
            if (model == null && (tmp = OnnxModelTransformer.transformFeature(feature, context.rankProfile())) instanceof ReferenceNode) {
                ReferenceNode newRefNode = (ReferenceNode)tmp;
                args = newRefNode.getArguments();
                arg = (ExpressionNode)args.expressions().get(0);
                model = models.get(arg.toString());
            }
            if (model == null) {
                throw new IllegalArgumentException("missing onnx model: " + arg);
            }
            model.getInputMap().forEach((__, onnxInput) -> {
                StringReader reader = new StringReader((String)onnxInput);
                try {
                    RankingExpression asExpression = new RankingExpression((Reader)reader);
                    this.transform(asExpression.getRoot(), context);
                }
                catch (ParseException e) {
                    throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage());
                }
            });
            return;
        }
        this.neededInputs.add(feature.toString());
    }
}

