/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges;
import com.facebook.presto.cost.PlanNodeCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * {@link CostCalculator} that derives the cost of a single plan node from the
 * statistics supplied by a {@link StatsProvider}.
 *
 * <p>Exchange nodes present in the plan are costed explicitly (see
 * {@code visitExchange}); join input and exchange formulas are delegated to the
 * static helpers on {@link CostCalculatorWithEstimatedExchanges}.
 *
 * <p>The only instance state is a final reference to the injected
 * {@link TaskCountEstimator}; a fresh visitor is created for every
 * {@link #calculateCost} call.
 */
@ThreadSafe
public class CostCalculatorUsingExchanges
        implements CostCalculator {
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param taskCountEstimator estimator queried for the source-distributed task count
     * @throws NullPointerException if {@code taskCountEstimator} is null
     */
    @Inject
    public CostCalculatorUsingExchanges(TaskCountEstimator taskCountEstimator) {
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Computes the cost estimate of {@code node} by dispatching it to a new {@code CostEstimator} visitor.
     *
     * @param node plan node to cost
     * @param stats provider of per-node statistics; must not be null
     * @param session not read by this implementation
     * @param types symbol type information passed to the output-size computations; must not be null
     * @return whatever the visitor returns for the node's type; node types without a
     *         dedicated visit method yield {@link PlanNodeCostEstimate#unknown()}
     */
    @Override
    public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, Session session, TypeProvider types) {
        CostEstimator costEstimator = new CostEstimator(stats, types, taskCountEstimator);
        return node.accept(costEstimator, null);
    }

    /**
     * Visitor producing a {@link PlanNodeCostEstimate} for the visited node only;
     * none of the visit methods recurse into child nodes.
     */
    private static class CostEstimator
            extends PlanVisitor<PlanNodeCostEstimate, Void> {
        private final StatsProvider stats;
        private final TypeProvider types;
        private final TaskCountEstimator taskCountEstimator;

        CostEstimator(StatsProvider stats, TypeProvider types, TaskCountEstimator taskCountEstimator) {
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        }

        /** Fallback for node types without a dedicated visit method: the cost is unknown. */
        @Override
        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context) {
            return PlanNodeCostEstimate.unknown();
        }

        /** Group references are not supported by this visitor and are always rejected. */
        @Override
        public PlanNodeCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        /** CPU cost equal to the estimated size of the generated id column alone. */
        @Override
        public PlanNodeCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context) {
            Collection<Symbol> idColumn = ImmutableList.of(node.getIdColumn());
            return PlanNodeCostEstimate.cpuCost(getStats(node).getOutputSizeInBytes(idColumn, types));
        }

        /**
         * CPU cost is the size of the costed symbols; memory cost is the size of the
         * source's output symbols, or zero when there is no partitioning.
         */
        @Override
        public PlanNodeCostEstimate visitRowNumber(RowNumberNode node, Void context) {
            Collection<Symbol> symbols = node.getOutputSymbols();
            // Without a per-partition row limit only the partitioning symbols and the
            // row-number symbol are costed; with a limit, every output symbol is.
            if (!node.getMaxRowCountPerPartition().isPresent()) {
                symbols = ImmutableList.<Symbol>builder()
                        .addAll(node.getPartitionBy())
                        .add(node.getRowNumberSymbol())
                        .build();
            }
            PlanNodeStatsEstimate nodeStats = getStats(node);
            double cpuCost = nodeStats.getOutputSizeInBytes(symbols, types);
            double memoryCost = node.getPartitionBy().isEmpty()
                    ? 0.0
                    : nodeStats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types);
            // Third component (presumably network cost -- confirm against PlanNodeCostEstimate) is zero.
            return new PlanNodeCostEstimate(cpuCost, memoryCost, 0.0);
        }

        /** Output nodes are costed as zero. */
        @Override
        public PlanNodeCostEstimate visitOutput(OutputNode node, Void context) {
            return PlanNodeCostEstimate.zero();
        }

        /** CPU cost equal to the estimated output size of the scan. */
        @Override
        public PlanNodeCostEstimate visitTableScan(TableScanNode node, Void context) {
            return PlanNodeCostEstimate.cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
        }

        /**
         * CPU cost computed from the <em>source's</em> statistics (not the filter's own),
         * sized over the filter's output symbols.
         */
        @Override
        public PlanNodeCostEstimate visitFilter(FilterNode node, Void context) {
            PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
            return PlanNodeCostEstimate.cpuCost(sourceStats.getOutputSizeInBytes(node.getOutputSymbols(), types));
        }

        /** CPU cost equal to the estimated output size of the projection. */
        @Override
        public PlanNodeCostEstimate visitProject(ProjectNode node, Void context) {
            return PlanNodeCostEstimate.cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
        }

        /**
         * Only {@code FINAL} and {@code SINGLE} aggregation steps are costed; any other
         * step yields an unknown estimate. CPU cost is the source's output size, memory
         * cost is the aggregation's output size.
         */
        @Override
        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context) {
            AggregationNode.Step step = node.getStep();
            if (step != AggregationNode.Step.FINAL && step != AggregationNode.Step.SINGLE) {
                return PlanNodeCostEstimate.unknown();
            }
            PlanNodeStatsEstimate aggregationStats = getStats(node);
            PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
            double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types);
            double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols(), types);
            return new PlanNodeCostEstimate(cpuCost, memoryCost, 0.0);
        }

        /** Join cost with left as probe and right as build; replicated only when the distribution type is present and {@code REPLICATED}. */
        @Override
        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context) {
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED));
            return calculateJoinCost(node, node.getLeft(), node.getRight(), replicated);
        }

        /**
         * Sums the join input cost (delegated to
         * {@link CostCalculatorWithEstimatedExchanges#calculateJoinInputCost}) and the
         * join output cost.
         */
        private PlanNodeCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated) {
            PlanNodeCostEstimate joinInputCost = CostCalculatorWithEstimatedExchanges.calculateJoinInputCost(
                    probe,
                    build,
                    stats,
                    types,
                    replicated,
                    taskCountEstimator.estimateSourceDistributedTaskCount());
            return joinInputCost.add(calculateJoinOutputCost(join));
        }

        /** CPU cost equal to the estimated output size of the join node. */
        private PlanNodeCostEstimate calculateJoinOutputCost(PlanNode join) {
            double joinOutputSize = getStats(join).getOutputSizeInBytes(join.getOutputSymbols(), types);
            return PlanNodeCostEstimate.cpuCost(joinOutputSize);
        }

        /**
         * Costs an exchange by scope and type. Local gather and local replicate are
         * zero; every other combination is delegated to the matching helper on
         * {@link CostCalculatorWithEstimatedExchanges}.
         *
         * @throws IllegalArgumentException if the scope or type is not one handled below
         */
        @Override
        public PlanNodeCostEstimate visitExchange(ExchangeNode node, Void context) {
            // Computed from the exchange node's own stats, before dispatching on scope/type.
            double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            switch (node.getScope()) {
                case LOCAL:
                    switch (node.getType()) {
                        case GATHER:
                            return PlanNodeCostEstimate.zero();
                        case REPARTITION:
                            return CostCalculatorWithEstimatedExchanges.calculateLocalRepartitionCost(inputSizeInBytes);
                        case REPLICATE:
                            return PlanNodeCostEstimate.zero();
                        default:
                            throw new IllegalArgumentException("Unexpected type: " + node.getType());
                    }
                case REMOTE:
                    switch (node.getType()) {
                        case GATHER:
                            return CostCalculatorWithEstimatedExchanges.calculateRemoteGatherCost(inputSizeInBytes);
                        case REPARTITION:
                            return CostCalculatorWithEstimatedExchanges.calculateRemoteRepartitionCost(inputSizeInBytes);
                        case REPLICATE:
                            return CostCalculatorWithEstimatedExchanges.calculateRemoteReplicateCost(
                                    inputSizeInBytes,
                                    taskCountEstimator.estimateSourceDistributedTaskCount());
                        default:
                            throw new IllegalArgumentException("Unexpected type: " + node.getType());
                    }
                default:
                    throw new IllegalArgumentException("Unexpected scope: " + node.getScope());
            }
        }

        /** Semi-join cost with the source as probe and the filtering source as build; an absent distribution type is treated as {@code PARTITIONED}. */
        @Override
        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            boolean replicated = node.getDistributionType()
                    .orElse(SemiJoinNode.DistributionType.PARTITIONED)
                    .equals(SemiJoinNode.DistributionType.REPLICATED);
            return calculateJoinCost(node, node.getSource(), node.getFilteringSource(), replicated);
        }

        /** Spatial-join cost with left as probe and right as build. */
        @Override
        public PlanNodeCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context) {
            boolean replicated = node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED;
            return calculateJoinCost(node, node.getLeft(), node.getRight(), replicated);
        }

        /** Values nodes are costed as zero. */
        @Override
        public PlanNodeCostEstimate visitValues(ValuesNode node, Void context) {
            return PlanNodeCostEstimate.zero();
        }

        /** Enforce-single-row nodes are costed as zero. */
        @Override
        public PlanNodeCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            return PlanNodeCostEstimate.zero();
        }

        /** CPU cost equal to the estimated output size of the limit node. */
        @Override
        public PlanNodeCostEstimate visitLimit(LimitNode node, Void context) {
            return PlanNodeCostEstimate.cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
        }

        /** Union nodes are costed as zero. */
        @Override
        public PlanNodeCostEstimate visitUnion(UnionNode node, Void context) {
            return PlanNodeCostEstimate.zero();
        }

        /** Looks up the statistics estimate for {@code node} in the stats provider. */
        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return stats.getStats(node);
        }
    }
}

