/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.PlanNodeCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * {@link CostCalculator} decorator that adds the estimated exchange (data movement) cost of a
 * plan node to the cost reported by a delegate calculator.
 * <p>
 * Exchange cost is derived from the estimated output size of the node's inputs, as reported by
 * the {@link StatsProvider}. For replicated joins it also uses the task count from the
 * {@link TaskCountEstimator}.
 */
@ThreadSafe
public class CostCalculatorWithEstimatedExchanges
        implements CostCalculator
{
    private final CostCalculator costCalculator;
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param costCalculator delegate whose estimate is extended with exchange costs
     * @param taskCountEstimator supplies the destination task count used for replicated joins
     * @throws NullPointerException if either argument is null
     */
    @Inject
    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, TaskCountEstimator taskCountEstimator)
    {
        this.costCalculator = Objects.requireNonNull(costCalculator, "costCalculator is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Returns the delegate's cost for {@code node} plus the node's estimated exchange cost.
     * Node types without a dedicated exchange model add a zero exchange cost.
     */
    @Override
    public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, Session session, TypeProvider types)
    {
        PlanNodeCostEstimate exchangeCost = node.accept(new ExchangeCostEstimator(stats, types, taskCountEstimator), null);
        PlanNodeCostEstimate nodeCost = costCalculator.calculateCost(node, stats, session, types);
        return nodeCost.add(exchangeCost);
    }

    /**
     * Cost of gathering {@code inputSizeInBytes} over the network. Only network cost is charged.
     */
    public static PlanNodeCostEstimate calculateRemoteGatherCost(double inputSizeInBytes)
    {
        return PlanNodeCostEstimate.networkCost(inputSizeInBytes);
    }

    /**
     * Cost of repartitioning {@code inputSizeInBytes} across the network. The input size is
     * charged both as CPU cost and as network cost, with no memory cost.
     */
    public static PlanNodeCostEstimate calculateRemoteRepartitionCost(double inputSizeInBytes)
    {
        double cpuCost = inputSizeInBytes;
        double networkCost = inputSizeInBytes;
        return new PlanNodeCostEstimate(cpuCost, 0.0, networkCost);
    }

    /**
     * Cost of a local repartition of {@code inputSizeInBytes}. Only CPU cost is charged.
     */
    public static PlanNodeCostEstimate calculateLocalRepartitionCost(double inputSizeInBytes)
    {
        return PlanNodeCostEstimate.cpuCost(inputSizeInBytes);
    }

    /**
     * Cost of sending a full copy of {@code inputSizeInBytes} to each of
     * {@code destinationTaskCount} tasks. The network cost is the input size times the task count.
     */
    public static PlanNodeCostEstimate calculateRemoteReplicateCost(double inputSizeInBytes, int destinationTaskCount)
    {
        double replicatedBytes = inputSizeInBytes * destinationTaskCount;
        return PlanNodeCostEstimate.networkCost(replicatedBytes);
    }

    /**
     * Estimated exchange cost of feeding a join.
     * <ul>
     * <li>Replicated: the build side is replicated to {@code estimatedSourceDistributedTaskCount}
     * tasks, plus a local repartition of the build side.</li>
     * <li>Otherwise: both sides are repartitioned remotely, plus a local repartition of the build
     * side.</li>
     * </ul>
     *
     * @param probe probe-side input of the join
     * @param build build-side input of the join
     * @param replicated whether the build side is replicated rather than repartitioned
     * @param estimatedSourceDistributedTaskCount number of tasks receiving a replicated build copy
     */
    public static PlanNodeCostEstimate calculateJoinExchangeCost(PlanNode probe, PlanNode build, StatsProvider stats, TypeProvider types, boolean replicated, int estimatedSourceDistributedTaskCount)
    {
        // Both sizes are looked up up-front, probe first, regardless of the join distribution.
        double probeBytes = outputSizeInBytes(probe, stats, types);
        double buildBytes = outputSizeInBytes(build, stats, types);

        if (!replicated) {
            return calculateRemoteRepartitionCost(probeBytes)
                    .add(calculateRemoteRepartitionCost(buildBytes))
                    .add(calculateLocalRepartitionCost(buildBytes));
        }
        return calculateRemoteReplicateCost(buildBytes, estimatedSourceDistributedTaskCount)
                .add(calculateLocalRepartitionCost(buildBytes));
    }

    /**
     * Estimated cost of consuming a join's inputs. No network cost is charged here.
     * <p>
     * The number of build copies is {@code estimatedSourceDistributedTaskCount} when replicated,
     * and 1 otherwise.
     * <ul>
     * <li>CPU cost is the probe size plus the build size times the number of copies. For
     * replicated builds, the build size times (copies - 1) is added on top.</li>
     * <li>Memory cost is the build size times the number of copies.</li>
     * </ul>
     */
    public static PlanNodeCostEstimate calculateJoinInputCost(PlanNode probe, PlanNode build, StatsProvider stats, TypeProvider types, boolean replicated, int estimatedSourceDistributedTaskCount)
    {
        double probeBytes = outputSizeInBytes(probe, stats, types);
        double buildBytes = outputSizeInBytes(build, stats, types);

        int buildCopies = replicated ? estimatedSourceDistributedTaskCount : 1;
        double allBuildCopiesBytes = buildBytes * buildCopies;

        double cpu = probeBytes + allBuildCopiesBytes;
        if (replicated) {
            // NOTE(review): this charges the extra build copies a second time. It presumably
            // penalizes broadcast joins on purpose; confirm before changing.
            cpu += buildBytes * (buildCopies - 1);
        }
        return new PlanNodeCostEstimate(cpu, allBuildCopiesBytes, 0.0);
    }

    /**
     * Estimated size in bytes of {@code node}'s output, taken over all of its output symbols.
     */
    private static double outputSizeInBytes(PlanNode node, StatsProvider stats, TypeProvider types)
    {
        PlanNodeStatsEstimate estimate = stats.getStats(node);
        return estimate.getOutputSizeInBytes(node.getOutputSymbols(), types);
    }

    /**
     * Plan visitor that returns the estimated exchange cost of a single plan node.
     * Node types without a dedicated {@code visit*} override contribute zero.
     */
    private static class ExchangeCostEstimator
            extends PlanVisitor<PlanNodeCostEstimate, Void>
    {
        private final StatsProvider stats;
        private final TypeProvider types;
        private final TaskCountEstimator taskCountEstimator;

        ExchangeCostEstimator(StatsProvider stats, TypeProvider types, TaskCountEstimator taskCountEstimator)
        {
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        }

        @Override
        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context)
        {
            return PlanNodeCostEstimate.zero();
        }

        @Override
        public PlanNodeCostEstimate visitGroupReference(GroupReference node, Void context)
        {
            // NOTE(review): presumably group references must be resolved before costing; confirm.
            throw new UnsupportedOperationException();
        }

        @Override
        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context)
        {
            // Modeled as a remote repartition followed by a local repartition of the aggregation input.
            double inputSizeInBytes = outputSizeInBytes(node.getSource(), stats, types);
            return calculateRemoteRepartitionCost(inputSizeInBytes)
                    .add(calculateLocalRepartitionCost(inputSizeInBytes));
        }

        @Override
        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context)
        {
            boolean replicated = Optional.of(JoinNode.DistributionType.REPLICATED).equals(node.getDistributionType());
            return calculateJoinExchangeCost(node.getLeft(), node.getRight(), stats, types, replicated, taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        @Override
        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
        {
            // The filtering source plays the build-side role.
            boolean replicated = Optional.of(SemiJoinNode.DistributionType.REPLICATED).equals(node.getDistributionType());
            return calculateJoinExchangeCost(node.getSource(), node.getFilteringSource(), stats, types, replicated, taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        @Override
        public PlanNodeCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context)
        {
            boolean replicated = node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED;
            return calculateJoinExchangeCost(node.getLeft(), node.getRight(), stats, types, replicated, taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        @Override
        public PlanNodeCostEstimate visitUnion(UnionNode node, Void context)
        {
            // Modeled as a remote gather of the union's own output size.
            // NOTE(review): this assumes every union input crosses the network, which may
            // overestimate; confirm.
            return calculateRemoteGatherCost(outputSizeInBytes(node, stats, types));
        }
    }
}

