/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Iterative-optimizer rule that rewrites {@code count(<constant>)} into {@code count()}.
 * <p>
 * An aggregation qualifies when:
 * <ul>
 *   <li>it is a single-argument {@code count};</li>
 *   <li>it is not DISTINCT;</li>
 *   <li>its argument is a non-null {@link Literal}, written either directly in the call or as a
 *       symbol the child {@link ProjectNode} assigns to such a literal.</li>
 * </ul>
 * Such an argument is never null, so counting it counts every input row.
 * <p>
 * The rule only fires when the aggregation's source resolves to a {@link ProjectNode}. When at
 * least one assignment is rewritten, it returns a new {@link AggregationNode} with the same id,
 * grouping sets, step, hash symbol and group-id symbol, reading from the resolved project node.
 */
public class SimplifyCountOverConstant
implements Rule {
    @Override
    public Optional<PlanNode> apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) {
        if (!(node instanceof AggregationNode)) {
            return Optional.empty();
        }
        AggregationNode parent = (AggregationNode) node;

        PlanNode input = lookup.resolve(parent.getSource());
        if (!(input instanceof ProjectNode)) {
            return Optional.empty();
        }
        ProjectNode child = (ProjectNode) input;

        boolean changed = false;
        // Start from a copy of every assignment (LinkedHashMap keeps the original iteration order)
        // and overwrite only the ones that qualify.
        Map<Symbol, AggregationNode.Aggregation> assignments = new LinkedHashMap<>(parent.getAssignments());
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : parent.getAssignments().entrySet()) {
            if (isCountOverConstant(entry.getValue(), child.getAssignments())) {
                changed = true;
                assignments.put(entry.getKey(), countAllRows());
            }
        }

        if (!changed) {
            return Optional.empty();
        }
        return Optional.of(new AggregationNode(
                node.getId(),
                child,
                assignments,
                parent.getGroupingSets(),
                parent.getStep(),
                parent.getHashSymbol(),
                parent.getGroupIdSymbol()));
    }

    /**
     * Builds a fresh zero-argument {@code count()} aggregation returning {@code bigint}.
     */
    private static AggregationNode.Aggregation countAllRows() {
        return new AggregationNode.Aggregation(
                new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                new Signature("count", FunctionKind.AGGREGATE, TypeSignature.parseTypeSignature("bigint")));
    }

    /**
     * Returns whether {@code aggregation} is a non-DISTINCT, single-argument {@code count} whose
     * argument is a non-null literal.
     *
     * @param aggregation the aggregation to inspect
     * @param inputs the child project's assignments, used to resolve a symbol argument to the
     *     expression that produces it
     * @return {@code true} if the aggregation can safely be replaced by {@code count()}
     */
    private static boolean isCountOverConstant(AggregationNode.Aggregation aggregation, Assignments inputs) {
        Signature signature = aggregation.getSignature();
        if (!signature.getName().equals("count") || signature.getArgumentTypes().size() != 1) {
            return false;
        }
        FunctionCall call = aggregation.getCall();
        // count(DISTINCT <constant>) yields at most 1, not the row count, and the replacement
        // built by countAllRows() is not DISTINCT, so these calls must be left untouched.
        if (call.isDistinct()) {
            return false;
        }

        Expression argument = call.getArguments().get(0);
        if (argument instanceof SymbolReference) {
            // Follow the symbol to the expression the child project assigns to it.
            argument = inputs.get(Symbol.from(argument));
        }
        // A NULL literal is excluded: count(NULL) is 0, not the row count.
        return argument instanceof Literal && !(argument instanceof NullLiteral);
    }
}

