/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.SystemSessionProperties;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DistinctAggregationStrategyChooser;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

public class MultipleDistinctAggregationToMarkDistinct
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching((Predicate)Predicates.and(MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask, (com.google.common.base.Predicate)Predicates.or(MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts, MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));
    private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser;

    private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().noneMatch(aggregation -> aggregation.isDistinct() && (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()));
    }

    private static boolean hasMultipleDistincts(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).map(AggregationNode.Aggregation::getArguments).map(HashSet::new).distinct().count() > 1L;
    }

    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) {
        long distincts = aggregationNode.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).count();
        return distincts > 0L && distincts < (long)aggregationNode.getAggregations().size();
    }

    public MultipleDistinctAggregationToMarkDistinct(TaskCountEstimator taskCountEstimator) {
        this.distinctAggregationStrategyChooser = DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(taskCountEstimator);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode parent, Captures captures, Rule.Context context) {
        OptimizerConfig.DistinctAggregationsStrategy distinctAggregationsStrategy = SystemSessionProperties.distinctAggregationsStrategy(context.getSession());
        if (!(distinctAggregationsStrategy.equals((Object)OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT) || distinctAggregationsStrategy.equals((Object)OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC) && this.distinctAggregationStrategyChooser.shouldAddMarkDistinct(parent, context.getSession(), context.getStatsProvider()))) {
            return Rule.Result.empty();
        }
        HashMap markers = new HashMap();
        HashMap<Symbol, AggregationNode.Aggregation> newAggregations = new HashMap<Symbol, AggregationNode.Aggregation>();
        PlanNode subPlan = parent.getSource();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.isDistinct() && aggregation.getFilter().isEmpty() && aggregation.getMask().isEmpty()) {
                Set inputs = aggregation.getArguments().stream().map(Symbol::from).collect(Collectors.toSet());
                Symbol marker = (Symbol)markers.get(inputs);
                if (marker == null) {
                    marker = context.getSymbolAllocator().newSymbol(((Symbol)Iterables.getLast(inputs)).name() + "_distinct", (Type)BooleanType.BOOLEAN);
                    markers.put(inputs, marker);
                    ImmutableSet.Builder distinctSymbols = ImmutableSet.builder().addAll(parent.getGroupingKeys()).addAll(inputs);
                    parent.getGroupIdSymbol().ifPresent(arg_0 -> ((ImmutableSet.Builder)distinctSymbols).add(arg_0));
                    subPlan = new MarkDistinctNode(context.getIdAllocator().getNextId(), subPlan, marker, (List<Symbol>)ImmutableList.copyOf((Collection)distinctSymbols.build()), Optional.empty());
                }
                newAggregations.put(entry.getKey(), new AggregationNode.Aggregation(aggregation.getResolvedFunction(), aggregation.getArguments(), false, aggregation.getFilter(), aggregation.getOrderingScheme(), Optional.of(marker)));
                continue;
            }
            newAggregations.put(entry.getKey(), aggregation);
        }
        return Rule.Result.ofPlanNode(AggregationNode.builderFrom(parent).setSource(subPlan).setAggregations(newAggregations).setPreGroupedSymbols((List<Symbol>)ImmutableList.of()).build());
    }
}

