/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.tree.FunctionCall;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

public class MultipleDistinctAggregationToMarkDistinct
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching((Predicate)Predicates.and(MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask, (com.google.common.base.Predicate)Predicates.or(MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts, MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));

    private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregation) {
        return aggregation.getAggregations().values().stream().noneMatch(e -> e.getCall().isDistinct() && (e.getCall().getFilter().isPresent() || e.getMask().isPresent()));
    }

    private static boolean hasMultipleDistincts(AggregationNode aggregation) {
        return aggregation.getAggregations().values().stream().filter(e -> e.getCall().isDistinct()).map(AggregationNode.Aggregation::getCall).map(FunctionCall::getArguments).map(HashSet::new).distinct().count() > 1L;
    }

    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregation) {
        long distincts = aggregation.getAggregations().values().stream().map(AggregationNode.Aggregation::getCall).filter(FunctionCall::isDistinct).count();
        return distincts > 0L && distincts < (long)aggregation.getAggregations().size();
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode parent, Captures captures, Rule.Context context) {
        if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
            return Rule.Result.empty();
        }
        HashMap markers = new HashMap();
        HashMap<Symbol, AggregationNode.Aggregation> newAggregations = new HashMap<Symbol, AggregationNode.Aggregation>();
        PlanNode subPlan = parent.getSource();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            FunctionCall call = aggregation.getCall();
            if (call.isDistinct() && !call.getFilter().isPresent() && !aggregation.getMask().isPresent()) {
                Set inputs = call.getArguments().stream().map(Symbol::from).collect(Collectors.toSet());
                Symbol marker = (Symbol)markers.get(inputs);
                if (marker == null) {
                    marker = context.getSymbolAllocator().newSymbol(((Symbol)Iterables.getLast(inputs)).getName(), (Type)BooleanType.BOOLEAN, "distinct");
                    markers.put(inputs, marker);
                    ImmutableSet.Builder distinctSymbols = ImmutableSet.builder().addAll(parent.getGroupingKeys()).addAll(inputs);
                    parent.getGroupIdSymbol().ifPresent(arg_0 -> ((ImmutableSet.Builder)distinctSymbols).add(arg_0));
                    subPlan = new MarkDistinctNode(context.getIdAllocator().getNextId(), subPlan, marker, (List<Symbol>)ImmutableList.copyOf((Collection)distinctSymbols.build()), Optional.empty());
                }
                newAggregations.put(entry.getKey(), new AggregationNode.Aggregation(new FunctionCall(call.getName(), call.getWindow(), call.getFilter(), call.getOrderBy(), false, call.getArguments()), aggregation.getFunctionHandle(), Optional.of(marker)));
                continue;
            }
            newAggregations.put(entry.getKey(), aggregation);
        }
        return Rule.Result.ofPlanNode(new AggregationNode(parent.getId(), subPlan, newAggregations, parent.getGroupingSets(), (List<Symbol>)ImmutableList.of(), parent.getStep(), parent.getHashSymbol(), parent.getGroupIdSymbol()));
    }
}

