/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.transform;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Optional;

/**
 * Rewrites two-argument {@code max}/{@code min} function nodes into a tensor
 * {@link Reduce} when the second argument names a dimension of the first
 * argument's tensor type.
 *
 * <p>Function nodes that do not match this shape are returned unchanged.
 */
public class TensorMaxMinTransformer<CONTEXT extends TransformContext>
extends ExpressionTransformer<CONTEXT> {

    /**
     * Transforms the children of composite nodes first, then rewrites the node
     * itself if it is a function node.
     *
     * @param node    the expression node to transform
     * @param context the transform context supplying the type context
     * @return the transformed node, or the (child-transformed) input node when no rewrite applies
     */
    @Override
    public ExpressionNode transform(ExpressionNode node, CONTEXT context) {
        if (node instanceof CompositeNode) {
            node = this.transformChildren((CompositeNode) node, context);
        }
        if (node instanceof FunctionNode) {
            node = transformFunctionNode((FunctionNode) node, context.types());
        }
        return node;
    }

    /**
     * Rewrites {@code min}/{@code max} function nodes to a tensor reduce where applicable.
     *
     * @param node    the function node to inspect
     * @param context the type context used to resolve the type of the first argument
     * @return a tensor function node performing the reduce, or {@code node} itself if it is not rewritten
     */
    public static ExpressionNode transformFunctionNode(FunctionNode node, TypeContext<Reference> context) {
        switch (node.getFunction()) {
            case min:
            case max:
                return transformMaxAndMinFunctionNode(node, context);
            default:
                return node;
        }
    }

    /**
     * Rewrites the node only when it has exactly two arguments and the second one
     * names a dimension present in the type of the first.
     */
    private static ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, TypeContext<Reference> context) {
        if (node.children().size() != 2) {
            return node;
        }
        ExpressionNode arg1 = node.children().get(0);
        Optional<String> dimension = dimensionName(node.children().get(1));
        if (!dimension.isPresent()) {
            return node;
        }
        // The type is resolved only once a candidate dimension name exists,
        // matching the original evaluation order.
        TensorType type = arg1.type(context);
        if (!type.dimension(dimension.get()).isPresent()) {
            return node;
        }
        return replaceMaxAndMinFunction(node, dimension.get());
    }

    /**
     * Extracts a candidate dimension name from the node: the reference name of an
     * identifier {@link ReferenceNode}, or the value of a {@link NameNode}.
     *
     * @return the name, or empty if the node is of neither form
     */
    private static Optional<String> dimensionName(ExpressionNode node) {
        if (node instanceof ReferenceNode) {
            Reference reference = ((ReferenceNode) node).reference();
            return reference.isIdentifier() ? Optional.of(reference.name()) : Optional.empty();
        }
        if (node instanceof NameNode) {
            return Optional.of(((NameNode) node).getValue());
        }
        return Optional.empty();
    }

    /**
     * Builds the reduce node that replaces the given {@code min}/{@code max} function node.
     *
     * <p>The dimension is passed in as the name already resolved by
     * {@link #dimensionName(ExpressionNode)} rather than re-derived by casting the
     * second argument to {@link ReferenceNode}; that cast failed with a
     * {@link ClassCastException} when the argument was a {@link NameNode}.
     *
     * @param node      the function node whose first child is the expression to reduce
     * @param dimension the name of the dimension to reduce over
     */
    private static ExpressionNode replaceMaxAndMinFunction(FunctionNode node, String dimension) {
        TensorFunctionNode.ExpressionTensorFunction argument = TensorFunctionNode.wrap(node.children().get(0));
        // The aggregator is looked up by the function's own name (min or max).
        Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
        return new TensorFunctionNode(new Reduce<Reference>(argument, aggregator, dimension));
    }
}

