/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchdefinition.expressiontransforms;

import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.document.Attribute;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.Optional;

/**
 * Rewrites two-argument {@code min}/{@code max} function calls into tensor reduce operations.
 *
 * <p>A call is rewritten only when its second argument is an argument-less reference and its
 * first argument, evaluated against a context of placeholder values, yields a tensor whose type
 * contains a dimension with that reference's name. Every other node is returned unchanged.
 *
 * <p>The placeholder context binds the attribute, constant, query, {@code tensorFrom*} and macro
 * references found in the first argument to empty values of the matching kind, so that only the
 * resulting value type matters, not the actual data.
 */
public class TensorTransformer extends ExpressionTransformer<RankProfileTransformContext> {

    /**
     * Transforms the children of {@code node} (when it is composite) and then the node itself
     * (when it is a function node).
     *
     * @param node the expression node to transform
     * @param context the transform context supplying the rank profile
     * @return the transformed node, or {@code node} itself when nothing applied
     */
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof CompositeNode) {
            node = transformChildren((CompositeNode) node, context);
        }
        if (node instanceof FunctionNode) {
            node = transformFunctionNode((FunctionNode) node, context.rankProfile());
        }
        return node;
    }

    /** Dispatches {@code min} and {@code max} calls to the reduce rewrite; other functions pass through. */
    private ExpressionNode transformFunctionNode(FunctionNode node, RankProfile profile) {
        switch (node.getFunction()) {
            case min:
            case max:
                return transformMaxAndMinFunctionNode(node, profile);
            default:
                return node;
        }
    }

    /**
     * Replaces {@code min(x, d)} / {@code max(x, d)} by a reduce over {@code d} when {@code x}
     * evaluates to a tensor that has a dimension named {@code d}.
     *
     * @return the replacement node, or {@code node} when the rewrite does not apply
     */
    private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, RankProfile profile) {
        if (node.children().size() != 2) {
            return node;
        }
        ExpressionNode arg1 = node.children().get(0);
        Optional<String> dimension = dimensionName(node.children().get(1));
        if (!dimension.isPresent()) {
            return node;
        }
        try {
            Value value = arg1.evaluate(buildContext(arg1, profile));
            if (isTensorWithDimension(value, dimension.get())) {
                return replaceMaxAndMinFunction(node);
            }
        } catch (IllegalArgumentException ignored) {
            // Best effort: when the first argument cannot be evaluated against the placeholder
            // context we cannot establish that it is a tensor, so the call is left untouched.
        }
        return node;
    }

    /** Returns the name of {@code arg} when it is a reference without arguments, otherwise empty. */
    private Optional<String> dimensionName(ExpressionNode arg) {
        if (arg instanceof ReferenceNode && ((ReferenceNode) arg).children().isEmpty()) {
            return Optional.of(((ReferenceNode) arg).getName());
        }
        return Optional.empty();
    }

    /** Returns whether {@code value} is a tensor value whose type has a dimension named {@code dimension}. */
    private boolean isTensorWithDimension(Value value, String dimension) {
        return value instanceof TensorValue
                && value.asTensor().type().dimensionNames().contains(dimension);
    }

    /**
     * Builds the reduce node replacing a two-argument {@code min}/{@code max} call.
     *
     * <p>The aggregator is looked up by the function's own name and the reduced dimension is the
     * name of the second argument, which the caller has already checked to be a reference.
     */
    private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
        ExpressionNode arg1 = node.children().get(0);
        ExpressionNode arg2 = node.children().get(1);

        TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
        Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
        String dimension = ((ReferenceNode) arg2).getName();

        return new TensorFunctionNode(new Reduce(expression, aggregator, dimension));
    }

    /** Creates a fresh context holding placeholder values for the references found under {@code node}. */
    private Context buildContext(ExpressionNode node, RankProfile profile) {
        Context context = new MapContext();
        addRoot(node, context, profile);
        return context;
    }

    /** Placeholder for string-typed features. */
    private Value emptyStringValue() {
        return new StringValue("");
    }

    /** Placeholder for numeric features. */
    private Value emptyDoubleValue() {
        return new DoubleValue(0.0);
    }

    /** Placeholder for tensor features: a tensor of the given type built without adding any cells. */
    private Value emptyTensorValue(TensorType type) {
        Tensor empty = Tensor.Builder.of(type).build();
        return new TensorValue(empty);
    }

    /**
     * Registers placeholders for everything below {@code node} and then for {@code node} itself.
     *
     * <p>Children are handled first so that their bindings are already in {@code context} when a
     * macro reference evaluates its argument expressions against it.
     */
    private void addRoot(ExpressionNode node, Context context, RankProfile profile) {
        addChildren(node, context, profile);
        if (node instanceof ReferenceNode) {
            ReferenceNode referenceNode = (ReferenceNode) node;
            addIfAttribute(referenceNode, context, profile);
            addIfConstant(referenceNode, context, profile);
            addIfQuery(referenceNode, context, profile);
            addIfTensorFrom(referenceNode, context);
            addIfMacro(referenceNode, context, profile);
        }
    }

    /** Recurses into the children of {@code node} when it is a composite node. */
    private void addChildren(ExpressionNode node, Context context, RankProfile profile) {
        if (!(node instanceof CompositeNode)) {
            return;
        }
        // Iterate the typed child list directly; going through a raw List would yield Object
        // elements that cannot be bound to an ExpressionNode loop variable.
        for (ExpressionNode child : ((CompositeNode) node).children()) {
            addRoot(child, context, profile);
        }
    }

    /**
     * Binds an {@code attribute(name)} reference to a placeholder matching the attribute's type:
     * string, tensor, or double for everything else. Unknown attributes are skipped.
     *
     * @throws IllegalStateException if the attribute is of tensor type but declares no tensor type
     */
    private void addIfAttribute(ReferenceNode node, Context context, RankProfile profile) {
        if (!node.getName().equals("attribute") || node.children().isEmpty()) {
            return;
        }
        String attributeName = node.children().get(0).toString();
        Attribute attribute = profile.getSearch().getAttribute(attributeName);
        if (attribute == null) {
            return;
        }

        Attribute.Type attributeType = attribute.getType();
        Value placeholder;
        if (attributeType == Attribute.Type.STRING) {
            placeholder = emptyStringValue();
        } else if (attributeType == Attribute.Type.TENSOR) {
            TensorType tensorType = attribute.tensorType().orElseThrow(() ->
                    new IllegalStateException("Attribute '" + attributeName
                            + "' is of tensor type but has no tensor type specification"));
            placeholder = emptyTensorValue(tensorType);
        } else {
            placeholder = emptyDoubleValue();
        }
        context.put(node.toString(), placeholder);
    }

    /**
     * Binds a single-argument {@code constant(...)} reference.
     *
     * <p>The constant name is the string form of the innermost first child of the argument. It is
     * looked up among the rank profile's constants first and the search definition's ranking
     * constants second.
     */
    private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) {
        if (!node.getName().equals("constant") || node.children().size() != 1) {
            return;
        }
        // Follow the chain of first children down to the innermost node.
        ExpressionNode leaf = node.children().get(0);
        while (leaf instanceof CompositeNode && !((CompositeNode) leaf).children().isEmpty()) {
            leaf = ((CompositeNode) leaf).children().get(0);
        }
        String name = leaf.toString();
        addIfConstantInRankProfile(name, node, context, profile);
        addIfConstantInRankingConstants(name, node, context, profile);
    }

    /** Binds {@code node} to the value the rank profile's constants hold under {@code name}, if any. */
    private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context, RankProfile profile) {
        if (profile.getConstants().containsKey(name)) {
            context.put(node.toString(), profile.getConstants().get(name));
        }
    }

    /** Binds {@code node} to an empty tensor of the ranking constant's type when one is named {@code name}. */
    private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context, RankProfile profile) {
        RankingConstant constant = profile.getSearch().getRankingConstants().get(name);
        if (constant != null) {
            context.put(node.toString(), emptyTensorValue(constant.getTensorType()));
        }
    }

    /**
     * Binds a single-argument {@code query(name)} reference whose type is declared in the rank
     * profile. A declared type containing "tensor" is parsed as a tensor type spec, "string"
     * (case-insensitive) gives a string placeholder, and anything else a double.
     */
    private void addIfQuery(ReferenceNode node, Context context, RankProfile profile) {
        if (!node.getName().equals("query") || node.children().size() != 1) {
            return;
        }
        String name = node.children().get(0).toString();
        if (!profile.getQueryFeatureTypes().containsKey(name)) {
            return;
        }
        String type = profile.getQueryFeatureTypes().get(name);

        Value placeholder;
        if (type.contains("tensor")) {
            placeholder = emptyTensorValue(TensorType.fromSpec(type));
        } else if (type.equalsIgnoreCase("string")) {
            placeholder = emptyStringValue();
        } else {
            placeholder = emptyDoubleValue();
        }
        context.put(node.toString(), placeholder);
    }

    /**
     * Binds a {@code tensorFrom*} reference with one or two arguments to an empty tensor having a
     * single mapped dimension.
     *
     * <p>The dimension is named by the second argument when present. Otherwise it is named by the
     * first argument, unwrapped one level when that argument is a non-empty composite node.
     */
    private void addIfTensorFrom(ReferenceNode node, Context context) {
        if (!node.getName().startsWith("tensorFrom")) {
            return;
        }
        int argumentCount = node.children().size();
        if (argumentCount < 1 || argumentCount > 2) {
            return;
        }

        ExpressionNode source = node.children().get(0);
        if (source instanceof CompositeNode && !((CompositeNode) source).children().isEmpty()) {
            source = ((CompositeNode) source).children().get(0);
        }
        String dimension = argumentCount == 2
                ? node.children().get(1).toString()
                : source.toString();

        TensorType type = new TensorType.Builder().mapped(dimension).build();
        context.put(node.toString(), emptyTensorValue(type));
    }

    /**
     * Binds a reference to a macro of the rank profile to the value of the macro's expression,
     * evaluated in its own placeholder context extended with the call's arguments.
     */
    private void addIfMacro(ReferenceNode node, Context context, RankProfile profile) {
        RankProfile.Macro macro = profile.getMacros().get(node.getName());
        if (macro == null) {
            return;
        }
        ExpressionNode root = macro.getRankingExpression().getRoot();
        // NOTE(review): a macro whose expression references itself would recurse without bound
        // through buildContext -> addRoot -> addIfMacro; confirm recursion is rejected upstream.
        Context macroContext = buildContext(root, profile);
        addMacroArguments(node, context, macro, macroContext);
        context.put(node.toString(), root.evaluate(macroContext));
    }

    /**
     * Evaluates the call's argument expressions in the caller's {@code context} and binds each
     * result to the corresponding formal parameter in {@code macroContext}. Surplus arguments or
     * parameters are ignored.
     */
    private void addMacroArguments(ReferenceNode node, Context context, RankProfile.Macro macro, Context macroContext) {
        List<String> formalParams = macro.getFormalParams();
        int bound = Math.min(formalParams.size(), node.children().size());
        for (int i = 0; i < bound; i++) {
            String param = formalParams.get(i);
            Value argument = node.children().get(i).evaluate(context);
            macroContext.put(param, argument);
        }
    }
}

