/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningHandle;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties;
import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.ProjectNodeUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multiset;
import io.airlift.units.DataSize;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Rules that repartition the input of a {@link GroupIdNode} feeding a PARTIAL {@link AggregationNode}
 * which sits directly (or behind an identity {@link ProjectNode}) below a remote {@link ExchangeNode}.
 *
 * <p>When the rule fires, a remote ({@code REMOTE_STREAMING}) and then a local hash exchange on
 * {@link SystemPartitioningHandle#FIXED_HASH_DISTRIBUTION} are inserted between the GroupId node and its
 * source. The hash columns are the grouping columns used by at least half of the grouping sets. If
 * distinct-value statistics are known for any of them, only the one with the highest NDV is used.
 *
 * <p>The rewrite is skipped when any of the following holds:
 * <ul>
 *   <li>the estimated partial-aggregation memory is unknown (NaN), or is below
 *       {@code TaskManagerConfig#getMaxPartialAggregationMemoryUsage()};</li>
 *   <li>no hash columns can be chosen;</li>
 *   <li>the source stream is already partitioned as required;</li>
 *   <li>the hash columns are estimated to form too few groups to keep all drivers busy.</li>
 * </ul>
 */
public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet {
    private static final Capture<ProjectNode> PROJECTION = Capture.newCapture();
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
    private static final Capture<GroupIdNode> GROUP_ID = Capture.newCapture();
    private static final Capture<ExchangeNode> REMOTE_EXCHANGE = Capture.newCapture();

    // remote exchange <- identity projection <- PARTIAL aggregation (non-empty grouping) <- GroupId
    private static final Pattern<ExchangeNode> WITH_PROJECTION = Pattern.typeOf(ExchangeNode.class)
            .matching(exchange -> exchange.getScope().isRemote())
            .capturedAs(REMOTE_EXCHANGE)
            .with(Patterns.source().matching(Pattern.typeOf(ProjectNode.class)
                    .matching(ProjectNodeUtils::isIdentity)
                    .capturedAs(PROJECTION)
                    .with(Patterns.source().matching(partialAggregationOverGroupId()))));

    // remote exchange <- PARTIAL aggregation (non-empty grouping) <- GroupId
    private static final Pattern<ExchangeNode> WITHOUT_PROJECTION = Pattern.typeOf(ExchangeNode.class)
            .matching(exchange -> exchange.getScope().isRemote())
            .capturedAs(REMOTE_EXCHANGE)
            .with(Patterns.source().matching(partialAggregationOverGroupId()));

    // A grouping column is a hash candidate only if it occurs in at least this fraction of the grouping sets.
    private static final double GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY = 0.5;
    // Estimated group count must be at least concurrency / ANTI_SKEWNESS_MARGIN for repartitioning to be applied.
    private static final double ANTI_SKEWNESS_MARGIN = 3.0;

    private final TaskCountEstimator taskCountEstimator;
    private final DataSize maxPartialAggregationMemoryUsage;
    private final Metadata metadata;
    private final boolean nativeExecution;

    public AddExchangesBelowPartialAggregationOverGroupIdRuleSet(TaskCountEstimator taskCountEstimator, TaskManagerConfig taskManagerConfig, Metadata metadata, boolean nativeExecution) {
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        this.maxPartialAggregationMemoryUsage = Objects.requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage();
        this.metadata = metadata;
        this.nativeExecution = nativeExecution;
    }

    /**
     * Builds the shared sub-pattern: a PARTIAL aggregation with at least one grouping column whose
     * source is a {@link GroupIdNode}, capturing both nodes.
     */
    private static Pattern<AggregationNode> partialAggregationOverGroupId() {
        return Pattern.typeOf(AggregationNode.class)
                .capturedAs(AGGREGATION)
                .with(Patterns.Aggregation.step().equalTo(AggregationNode.Step.PARTIAL))
                .with(Pattern.nonEmpty(Patterns.Aggregation.groupingColumns()))
                .with(Patterns.source().matching(Pattern.typeOf(GroupIdNode.class).capturedAs(GROUP_ID)));
    }

    /**
     * Returns both rule variants: with and without an identity projection between the remote exchange
     * and the partial aggregation.
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(belowProjectionRule(), belowExchangeRule());
    }

    @VisibleForTesting
    AddExchangesBelowExchangePartialAggregationGroupId belowExchangeRule() {
        return new AddExchangesBelowExchangePartialAggregationGroupId();
    }

    @VisibleForTesting
    AddExchangesBelowProjectionPartialAggregationGroupId belowProjectionRule() {
        return new AddExchangesBelowProjectionPartialAggregationGroupId();
    }

    /** Variant matching {@code remote exchange -> identity project -> partial aggregation -> GroupId}. */
    @VisibleForTesting
    class AddExchangesBelowProjectionPartialAggregationGroupId
    extends BaseAddExchangesBelowExchangePartialAggregationGroupId {
        AddExchangesBelowProjectionPartialAggregationGroupId() {
        }

        @Override
        public Pattern<ExchangeNode> getPattern() {
            return WITH_PROJECTION;
        }

        @Override
        public Rule.Result apply(ExchangeNode exchange, Captures captures, Rule.Context context) {
            ProjectNode project = captures.get(PROJECTION);
            AggregationNode aggregation = captures.get(AGGREGATION);
            GroupIdNode groupId = captures.get(GROUP_ID);
            return transform(aggregation, groupId, context)
                    .map(newAggregation -> {
                        PlanNode newProject = project.replaceChildren(ImmutableList.of(newAggregation));
                        return Rule.Result.ofPlanNode(exchange.replaceChildren(ImmutableList.of(newProject)));
                    })
                    .orElseGet(Rule.Result::empty);
        }
    }

    /** Variant matching {@code remote exchange -> partial aggregation -> GroupId}. */
    @VisibleForTesting
    class AddExchangesBelowExchangePartialAggregationGroupId
    extends BaseAddExchangesBelowExchangePartialAggregationGroupId {
        AddExchangesBelowExchangePartialAggregationGroupId() {
        }

        @Override
        public Pattern<ExchangeNode> getPattern() {
            return WITHOUT_PROJECTION;
        }

        @Override
        public Rule.Result apply(ExchangeNode exchange, Captures captures, Rule.Context context) {
            AggregationNode aggregation = captures.get(AGGREGATION);
            GroupIdNode groupId = captures.get(GROUP_ID);
            return transform(aggregation, groupId, context)
                    .map(newAggregation -> Rule.Result.ofPlanNode(exchange.replaceChildren(ImmutableList.of(newAggregation))))
                    .orElseGet(Rule.Result::empty);
        }
    }

    private abstract class BaseAddExchangesBelowExchangePartialAggregationGroupId
    implements Rule<ExchangeNode> {
        private BaseAddExchangesBelowExchangePartialAggregationGroupId() {
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isEnabledAddExchangeBelowGroupId(session);
        }

        /**
         * Inserts remote and local hash exchanges below {@code groupId} when the heuristics described on
         * the enclosing class allow it.
         *
         * @return the rewritten aggregation (with a new GroupId child), or empty when the plan is left as is
         */
        protected Optional<PlanNode> transform(AggregationNode aggregation, GroupIdNode groupId, Rule.Context context) {
            Set<VariableReferenceExpression> groupingKeys = aggregation.getGroupingKeys().stream()
                    .filter(variable -> !groupId.getGroupIdVariable().equals(variable))
                    .collect(ImmutableSet.toImmutableSet());
            // How many grouping sets each variable appears in.
            Multiset<VariableReferenceExpression> groupingSetHistogram = groupId.getGroupingSets().stream()
                    .flatMap(Collection::stream)
                    .collect(ImmutableMultiset.toImmutableMultiset());
            if (!Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)) {
                // Only handle the case where the aggregation groups exactly on the grouping-set variables.
                return Optional.empty();
            }

            double aggregationMemoryRequirements = estimateAggregationMemoryRequirements(groupingKeys, groupId, groupingSetHistogram, context);
            if (Double.isNaN(aggregationMemoryRequirements) || aggregationMemoryRequirements < maxPartialAggregationMemoryUsage.toBytes()) {
                // Estimate is unknown, or fits under the configured partial aggregation memory limit.
                return Optional.empty();
            }

            PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource());
            List<VariableReferenceExpression> desiredHashVariables = selectHashVariables(groupId, groupingKeys, groupingSetHistogram, sourceStats);
            if (desiredHashVariables.isEmpty()) {
                // No column is common enough across grouping sets; a hash exchange on zero columns is meaningless.
                return Optional.empty();
            }

            StreamPreferredProperties requiredProperties = StreamPreferredProperties.fixedParallelism().withPartitioning(desiredHashVariables);
            if (requiredProperties.isSatisfiedBy(derivePropertiesRecursively(groupId.getSource(), context))) {
                // Source is already partitioned as required.
                return Optional.empty();
            }

            double estimatedGroups = estimateGroupCount(desiredHashVariables, sourceStats);
            if (Double.isNaN(estimatedGroups) || estimatedGroups * ANTI_SKEWNESS_MARGIN < maximalConcurrencyAfterRepartition(context)) {
                // Too few (or unknown) groups: hashing on these columns could leave drivers idle.
                return Optional.empty();
            }

            PlanNode repartitionedSource = addHashExchanges(groupId.getSource(), desiredHashVariables, context);
            PlanNode newGroupId = groupId.replaceChildren(ImmutableList.of(repartitionedSource));
            return Optional.of(aggregation.replaceChildren(ImmutableList.of(newGroupId)));
        }

        /**
         * Picks the pre-GroupId source variables to hash on.
         *
         * <p>Candidates are the grouping-set variables present in at least
         * {@link #GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY} of the grouping sets, mapped through
         * {@link GroupIdNode#getGroupingColumns()}. If any candidate has a non-NaN distinct-values count,
         * only the candidate with the largest count is returned. Otherwise all candidates are returned,
         * and the list may be empty.
         */
        private List<VariableReferenceExpression> selectHashVariables(GroupIdNode groupId, Set<VariableReferenceExpression> groupingKeys, Multiset<VariableReferenceExpression> groupingSetHistogram, PlanNodeStatsEstimate sourceStats) {
            double requiredOccurrences = groupId.getGroupingSets().size() * GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY;
            List<VariableReferenceExpression> candidates = groupingSetHistogram.entrySet().stream()
                    .filter(entry -> entry.getCount() >= requiredOccurrences)
                    .map(Multiset.Entry::getElement)
                    .map(variable -> {
                        Verify.verify(groupingKeys.contains(variable), "%s not found in the grouping keys [%s]", variable, groupingKeys);
                        return groupId.getGroupingColumns().get(variable);
                    })
                    .collect(ImmutableList.toImmutableList());

            Optional<VariableReferenceExpression> mostDistinct = candidates.stream()
                    .filter(variable -> !Double.isNaN(sourceStats.getVariableStatistics(variable).getDistinctValuesCount()))
                    .max(Comparator.comparingDouble(variable -> sourceStats.getVariableStatistics(variable).getDistinctValuesCount()));
            return mostDistinct.isPresent() ? ImmutableList.of(mostDistinct.get()) : candidates;
        }

        /**
         * Wraps {@code source} in a remote ({@code REMOTE_STREAMING}) hash exchange, then a local hash
         * exchange, both partitioned on {@code hashVariables} and passing through all output variables.
         */
        private PlanNode addHashExchanges(PlanNode source, List<VariableReferenceExpression> hashVariables, Rule.Context context) {
            PlanNode remoteExchange = ExchangeNode.partitionedExchange(
                    context.getIdAllocator().getNextId(),
                    ExchangeNode.Scope.REMOTE_STREAMING,
                    source,
                    new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION, hashVariables), source.getOutputVariables()));
            return ExchangeNode.partitionedExchange(
                    context.getIdAllocator().getNextId(),
                    ExchangeNode.Scope.LOCAL,
                    remoteExchange,
                    new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION, hashVariables), remoteExchange.getOutputVariables()));
        }

        /** Session task concurrency multiplied by the estimated number of hash-partitioned tasks. */
        private int maximalConcurrencyAfterRepartition(Rule.Context context) {
            return SystemSessionProperties.getTaskConcurrency(context.getSession()) * taskCountEstimator.estimateHashedTaskCount(context.getSession());
        }

        /**
         * Sums, over all grouping sets, the estimated per-row key width times the estimated number of
         * distinct keys (capped at the source row count). Returns NaN when the statistics are unknown.
         */
        private double estimateAggregationMemoryRequirements(Set<VariableReferenceExpression> groupingKeys, GroupIdNode groupId, Multiset<VariableReferenceExpression> groupingSetHistogram, Rule.Context context) {
            Preconditions.checkArgument(Objects.equals(groupingSetHistogram.elementSet(), groupingKeys));
            PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource());
            double keysMemoryRequirements = 0.0;
            for (List<VariableReferenceExpression> groupingSet : groupId.getGroupingSets()) {
                List<VariableReferenceExpression> sourceVariables = groupingSet.stream()
                        .map(groupId.getGroupingColumns()::get)
                        .collect(ImmutableList.toImmutableList());
                double keyWidth = sourceStats.getOutputSizeForVariables(sourceVariables) / sourceStats.getOutputRowCount();
                double keyNdv = Math.min(estimateGroupCount(sourceVariables, sourceStats), sourceStats.getOutputRowCount());
                keysMemoryRequirements += keyWidth * keyNdv;
            }
            return keysMemoryRequirements;
        }

        /** Product of per-variable distinct counts (NULL counted as one extra value); 1.0 for an empty list. */
        private double estimateGroupCount(List<VariableReferenceExpression> variables, PlanNodeStatsEstimate statsEstimate) {
            return variables.stream()
                    .map(statsEstimate::getVariableStatistics)
                    .mapToDouble(this::ndvIncludingNull)
                    .reduce(1.0, (a, b) -> a * b);
        }

        /** Distinct-values count, plus one when the nulls fraction is not exactly zero. */
        private double ndvIncludingNull(VariableStatsEstimate variableStatsEstimate) {
            if (variableStatsEstimate.getNullsFraction() == 0.0) {
                return variableStatsEstimate.getDistinctValuesCount();
            }
            return variableStatsEstimate.getDistinctValuesCount() + 1.0;
        }

        /** Resolves {@code node} through the lookup and derives its stream properties bottom-up. */
        private StreamPropertyDerivations.StreamProperties derivePropertiesRecursively(PlanNode node, Rule.Context context) {
            PlanNode resolvedPlanNode = context.getLookup().resolve(node);
            List<StreamPropertyDerivations.StreamProperties> inputProperties = resolvedPlanNode.getSources().stream()
                    .map(source -> derivePropertiesRecursively(source, context))
                    .collect(ImmutableList.toImmutableList());
            return StreamPropertyDerivations.deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), nativeExecution);
        }
    }
}

