/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.PlanNodeCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
import com.facebook.presto.metadata.InternalNodeManager;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.IntSupplier;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * {@link CostCalculator} that estimates the CPU, memory and network cost of a single plan node
 * from the output-size statistics supplied by a {@link StatsProvider}.
 *
 * <p>Exchange nodes are costed by {@link #calculateExchangeCost}; plan node types without a
 * dedicated rule are reported as {@link PlanNodeCostEstimate#UNKNOWN_COST}. The node count used
 * for replicated joins and replicating exchanges is read from the configured supplier once per
 * {@link #calculateCost} call, and a fresh visitor is created for every call.
 */
@ThreadSafe
public class CostCalculatorUsingExchanges implements CostCalculator {
    /** Source of the node count; queried once per {@link #calculateCost} invocation. */
    private final IntSupplier numberOfNodes;

    /**
     * Creates a calculator whose node count follows the active nodes known to {@code nodeManager}.
     *
     * @param nodeSchedulerConfig its {@code isIncludeCoordinator()} flag decides whether coordinator
     *     nodes are counted
     * @param nodeManager source of the active node set
     */
    @Inject
    public CostCalculatorUsingExchanges(NodeSchedulerConfig nodeSchedulerConfig, InternalNodeManager nodeManager) {
        this(currentNumberOfWorkerNodes(nodeSchedulerConfig.isIncludeCoordinator(), nodeManager));
    }

    /**
     * Returns a supplier that, on every invocation, counts the nodes {@code nodeManager} currently
     * reports as active. The count is not cached.
     *
     * @param includeCoordinator when {@code false}, nodes whose {@code isCoordinator()} is true are
     *     excluded from the count
     * @param nodeManager source of the active node set; must not be null
     * @return a live node-count supplier
     * @throws NullPointerException if {@code nodeManager} is null
     */
    static IntSupplier currentNumberOfWorkerNodes(boolean includeCoordinator, InternalNodeManager nodeManager) {
        Objects.requireNonNull(nodeManager, "nodeManager is null");
        return () -> {
            Set<Node> activeNodes = nodeManager.getAllNodes().getActiveNodes();
            if (includeCoordinator) {
                return activeNodes.size();
            }
            return Math.toIntExact(activeNodes.stream().filter(node -> !node.isCoordinator()).count());
        };
    }

    /**
     * Creates a calculator with an explicit node-count source.
     *
     * @param numberOfNodes supplier queried once per {@link #calculateCost} call; must not be null
     * @throws NullPointerException if {@code numberOfNodes} is null
     */
    public CostCalculatorUsingExchanges(IntSupplier numberOfNodes) {
        this.numberOfNodes = Objects.requireNonNull(numberOfNodes, "numberOfNodes is null");
    }

    /**
     * Estimates the cost of {@code node} alone (not of its sources).
     *
     * <p>Only {@code node} and {@code stats} are consulted; {@code lookup}, {@code session} and
     * {@code types} are not used by this implementation.
     */
    @Override
    public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, Lookup lookup, Session session, Map<Symbol, Type> types) {
        CostEstimator costEstimator = new CostEstimator(numberOfNodes.getAsInt(), stats);
        return node.accept(costEstimator, null);
    }

    /**
     * Computes the cost of moving {@code symbols} of {@code exchangeStats} through an exchange.
     *
     * <ul>
     *   <li>{@code GATHER}: network = exchange size.</li>
     *   <li>{@code REPARTITION}: network = CPU = exchange size.</li>
     *   <li>{@code REPLICATE}: network = exchange size * {@code numberOfNodes}.</li>
     * </ul>
     * For {@link ExchangeNode.Scope#LOCAL} exchanges the network cost is forced to 0. Memory cost
     * is always 0.
     *
     * @throws UnsupportedOperationException for any other exchange type
     */
    public static PlanNodeCostEstimate calculateExchangeCost(int numberOfNodes, PlanNodeStatsEstimate exchangeStats, List<Symbol> symbols, ExchangeNode.Type type, ExchangeNode.Scope scope) {
        double exchangeSize = exchangeStats.getOutputSizeInBytes(symbols);
        double network;
        double cpu = 0.0;
        switch (type) {
            case GATHER:
                network = exchangeSize;
                break;
            case REPARTITION:
                network = exchangeSize;
                cpu = exchangeSize;
                break;
            case REPLICATE:
                network = exchangeSize * numberOfNodes;
                break;
            default:
                throw new UnsupportedOperationException(String.format("Unsupported type [%s] of the exchange", type));
        }
        if (scope == ExchangeNode.Scope.LOCAL) {
            // Local exchanges do not cross the network.
            network = 0.0;
        }
        return new PlanNodeCostEstimate(cpu, 0.0, network);
    }

    /**
     * Per-call visitor mapping each supported plan node type to a cost estimate, using a fixed
     * node count captured at construction.
     */
    private static class CostEstimator extends PlanVisitor<PlanNodeCostEstimate, Void> {
        private final int numberOfNodes;
        private final StatsProvider stats;

        CostEstimator(int numberOfNodes, StatsProvider stats) {
            this.numberOfNodes = numberOfNodes;
            this.stats = Objects.requireNonNull(stats, "stats is null");
        }

        /** Fallback for node types without a dedicated rule. */
        @Override
        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context) {
            return PlanNodeCostEstimate.UNKNOWN_COST;
        }

        /** Group references are not costed by this visitor. */
        @Override
        public PlanNodeCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        @Override
        public PlanNodeCostEstimate visitOutput(OutputNode node, Void context) {
            return PlanNodeCostEstimate.ZERO_COST;
        }

        /** CPU = size of the scanned output. */
        @Override
        public PlanNodeCostEstimate visitTableScan(TableScanNode node, Void context) {
            return PlanNodeCostEstimate.cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
        }

        /** CPU = size of the filter's input rows (source stats), measured over the filter's output symbols. */
        @Override
        public PlanNodeCostEstimate visitFilter(FilterNode node, Void context) {
            return PlanNodeCostEstimate.cpuCost(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols()));
        }

        /** CPU = size of the projected output. */
        @Override
        public PlanNodeCostEstimate visitProject(ProjectNode node, Void context) {
            return PlanNodeCostEstimate.cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
        }

        /** CPU = size of the aggregation input; memory = size of the aggregation output. */
        @Override
        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context) {
            PlanNodeStatsEstimate aggregationStats = getStats(node);
            PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
            double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols());
            double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols());
            return new PlanNodeCostEstimate(cpuCost, memoryCost, 0.0);
        }

        /** Left side is the probe, right side the build; replicated iff distribution type is REPLICATED. */
        @Override
        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context) {
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED));
            return calculateJoinCost(node, node.getLeft(), node.getRight(), replicated);
        }

        /**
         * Shared join costing for {@link JoinNode} and {@link SemiJoinNode}.
         *
         * <p>CPU = probe size + build size * multiplier + join output size, where the multiplier is
         * the node count for replicated joins and 1 otherwise. Replicated joins additionally pay
         * build size * (nodes - 1). Memory = build size * multiplier. Network cost is 0.
         */
        private PlanNodeCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated) {
            int numberOfNodesMultiplier = replicated ? numberOfNodes : 1;
            PlanNodeStatsEstimate probeStats = getStats(probe);
            PlanNodeStatsEstimate buildStats = getStats(build);
            PlanNodeStatsEstimate outputStats = getStats(join);
            double buildSideSize = buildStats.getOutputSizeInBytes(build.getOutputSymbols());
            double probeSideSize = probeStats.getOutputSizeInBytes(probe.getOutputSymbols());
            double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols());

            double cpuCost = probeSideSize + buildSideSize * numberOfNodesMultiplier + joinOutputSize;
            if (replicated) {
                // One extra build-side charge per node beyond the first.
                // NOTE(review): presumably models local handling of the replicated copies; confirm intent.
                // Clamped at zero so an empty cluster (numberOfNodes == 0) cannot yield a negative cost.
                cpuCost += buildSideSize * Math.max(numberOfNodes - 1, 0);
            }
            double memoryCost = buildSideSize * numberOfNodesMultiplier;
            return new PlanNodeCostEstimate(cpuCost, memoryCost, 0.0);
        }

        @Override
        public PlanNodeCostEstimate visitExchange(ExchangeNode node, Void context) {
            return calculateExchangeCost(numberOfNodes, getStats(node), node.getOutputSymbols(), node.getType(), node.getScope());
        }

        /** Source is the probe, filtering source the build; an absent distribution type counts as PARTITIONED. */
        @Override
        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            boolean replicated = node.getDistributionType()
                    .orElse(SemiJoinNode.DistributionType.PARTITIONED)
                    .equals(SemiJoinNode.DistributionType.REPLICATED);
            return calculateJoinCost(node, node.getSource(), node.getFilteringSource(), replicated);
        }

        @Override
        public PlanNodeCostEstimate visitValues(ValuesNode node, Void context) {
            return PlanNodeCostEstimate.ZERO_COST;
        }

        @Override
        public PlanNodeCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            return PlanNodeCostEstimate.ZERO_COST;
        }

        /** CPU = size of the limited output. */
        @Override
        public PlanNodeCostEstimate visitLimit(LimitNode node, Void context) {
            return PlanNodeCostEstimate.cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
        }

        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return stats.getStats(node);
        }
    }
}

