/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchdefinition.expressiontransforms;

import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Optional;

public class TensorTransformer
extends ExpressionTransformer<RankProfileTransformContext> {
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof CompositeNode) {
            node = this.transformChildren((CompositeNode)node, context);
        }
        if (node instanceof FunctionNode) {
            node = this.transformFunctionNode((FunctionNode)node, context);
        }
        return node;
    }

    private ExpressionNode transformFunctionNode(FunctionNode node, RankProfileTransformContext context) {
        switch (node.getFunction()) {
            case min: 
            case max: {
                return this.transformMaxAndMinFunctionNode(node, context);
            }
        }
        return node;
    }

    private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, RankProfileTransformContext context) {
        TensorType type;
        if (node.children().size() != 2) {
            return node;
        }
        ExpressionNode arg1 = (ExpressionNode)node.children().get(0);
        Optional<String> dimension = this.dimensionName((ExpressionNode)node.children().get(1));
        if (dimension.isPresent() && (type = arg1.type((TypeContext)context.rankProfile().typeContext(context.queryProfiles()))).dimension(dimension.get()).isPresent()) {
            return this.replaceMaxAndMinFunction(node);
        }
        return node;
    }

    private Optional<String> dimensionName(ExpressionNode node) {
        if (node instanceof ReferenceNode) {
            Reference reference = ((ReferenceNode)node).reference();
            if (reference.isIdentifier()) {
                return Optional.of(reference.name());
            }
            return Optional.empty();
        }
        if (node instanceof NameNode) {
            return Optional.of(((NameNode)node).getValue());
        }
        return Optional.empty();
    }

    private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
        ExpressionNode arg1 = (ExpressionNode)node.children().get(0);
        ExpressionNode arg2 = (ExpressionNode)node.children().get(1);
        TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument((ExpressionNode)arg1);
        Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf((String)node.getFunction().name());
        String dimension = ((ReferenceNode)arg2).getName();
        return new TensorFunctionNode((TensorFunction)new Reduce((TensorFunction)expression, aggregator, dimension));
    }
}

