package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeCost;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.Constraint;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.statistics.Estimate;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.DomainTranslator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * {@link CostCalculator} that estimates the output row count of every node in a plan using fixed,
 * hard-coded coefficients.
 *
 * <p>The only statistic consulted is the table row count returned by
 * {@link Metadata#getTableStatistics} for table scans. Every other estimate is derived from it:
 * <ul>
 *   <li>filters multiply their source's row count by {@code FILTER_COEFFICIENT};</li>
 *   <li>joins take the larger of their two source row counts times
 *       {@code JOIN_MATCHING_COEFFICIENT} (unknown if either side is unknown);</li>
 *   <li>semi joins multiply the row count of {@code getSource()} by
 *       {@code JOIN_MATCHING_COEFFICIENT};</li>
 *   <li>exchanges sum the row counts of their sources (unknown if any source is unknown);</li>
 *   <li>limits report the smaller of their source row count and the limit count;</li>
 *   <li>values nodes report their literal row count; enforce-single-row nodes report 1;</li>
 *   <li>output and project nodes pass their source's cost through unchanged;</li>
 *   <li>any other node is assigned {@link PlanNodeCost#UNKNOWN_COST}.</li>
 * </ul>
 *
 * <p>Each call to {@link #calculateCostForPlan} walks the plan with a fresh {@link Visitor}, so no
 * mutable state is shared between calls.
 */
@ThreadSafe
public class CoefficientBasedCostCalculator implements CostCalculator {
    /** Fraction of a filter's input rows assumed to pass the predicate (fixed heuristic). */
    private static final double FILTER_COEFFICIENT = 0.5;
    /** Multiplier applied to join input row counts (fixed heuristic). */
    private static final double JOIN_MATCHING_COEFFICIENT = 2.0;

    private final Metadata metadata;

    /**
     * Walks a plan, recording an estimated {@link PlanNodeCost} for every visited node.
     *
     * <p>Not thread-safe: each instance accumulates results in its own {@link HashMap} and is
     * created for a single traversal.
     */
    private class Visitor extends PlanVisitor<Void, PlanNodeCost> {
        private final Session session;
        private final Map<PlanNodeId, PlanNodeCost> costs = new HashMap<>();
        private final Map<Symbol, Type> types;

        /**
         * @param session session passed to metadata and domain-translation calls
         * @param types symbol types of the plan; copied defensively
         */
        public Visitor(Session session, Map<Symbol, Type> types) {
            this.session = session;
            this.types = ImmutableMap.copyOf(types);
        }

        /** Returns an immutable snapshot of all costs recorded so far, keyed by plan node id. */
        public Map<PlanNodeId, PlanNodeCost> getCosts() {
            return ImmutableMap.copyOf(costs);
        }

        /** Fallback for node types without a dedicated estimate: children are still visited. */
        @Override
        protected PlanNodeCost visitPlan(PlanNode node, Void context) {
            visitSources(node);
            costs.put(node.getId(), PlanNodeCost.UNKNOWN_COST);
            return PlanNodeCost.UNKNOWN_COST;
        }

        @Override
        public PlanNodeCost visitOutput(OutputNode node, Void context) {
            return copySourceCost(node);
        }

        @Override
        public PlanNodeCost visitFilter(FilterNode node, Void context) {
            PlanNode source = node.getSource();
            // When filtering directly over a scan, hand the predicate to the table statistics
            // lookup as a constraint; otherwise just cost the source normally.
            PlanNodeCost sourceCost = source instanceof TableScanNode
                    ? visitTableScanWithPredicate((TableScanNode) source, node.getPredicate())
                    : visitSource(node);
            PlanNodeCost filterCost = sourceCost.mapOutputRowCount(rows -> rows * FILTER_COEFFICIENT);
            costs.put(node.getId(), filterCost);
            return filterCost;
        }

        @Override
        public PlanNodeCost visitProject(ProjectNode node, Void context) {
            return copySourceCost(node);
        }

        @Override
        public PlanNodeCost visitJoin(JoinNode node, Void context) {
            List<PlanNodeCost> sourceCosts = visitSources(node);
            Estimate firstRows = sourceCosts.get(0).getOutputRowCount();
            Estimate secondRows = sourceCosts.get(1).getOutputRowCount();

            PlanNodeCost.Builder builder = PlanNodeCost.builder();
            // Leave the row count unset (builder default) unless both inputs are known.
            if (!firstRows.isValueUnknown() && !secondRows.isValueUnknown()) {
                double largerInput = Math.max(firstRows.getValue(), secondRows.getValue());
                builder.setOutputRowCount(new Estimate(largerInput * JOIN_MATCHING_COEFFICIENT));
            }
            PlanNodeCost joinCost = builder.build();
            costs.put(node.getId(), joinCost);
            return joinCost;
        }

        @Override
        public PlanNodeCost visitExchange(ExchangeNode node, Void context) {
            Estimate totalRows = new Estimate(0);
            for (PlanNodeCost sourceCost : visitSources(node)) {
                Estimate sourceRows = sourceCost.getOutputRowCount();
                if (sourceRows.isValueUnknown()) {
                    // A single unknown input makes the total unknown.
                    totalRows = Estimate.unknownValue();
                    break;
                }
                totalRows = totalRows.map(rows -> rows + sourceRows.getValue());
            }
            PlanNodeCost exchangeCost = PlanNodeCost.builder().setOutputRowCount(totalRows).build();
            costs.put(node.getId(), exchangeCost);
            return exchangeCost;
        }

        @Override
        public PlanNodeCost visitTableScan(TableScanNode node, Void context) {
            return visitTableScanWithPredicate(node, BooleanLiteral.TRUE_LITERAL);
        }

        /**
         * Costs a table scan using the row count reported by table statistics under the given
         * predicate, and records the result under the scan's id.
         */
        private PlanNodeCost visitTableScanWithPredicate(TableScanNode node, Expression predicate) {
            Constraint<ColumnHandle> constraint = getConstraint(node, predicate);
            PlanNodeCost tableScanCost = PlanNodeCost.builder()
                    .setOutputRowCount(metadata.getTableStatistics(session, node.getTable(), constraint).getRowCount())
                    .build();
            costs.put(node.getId(), tableScanCost);
            return tableScanCost;
        }

        /**
         * Builds a column-level constraint for {@code node}: the predicate is translated into a
         * symbol domain, re-keyed by the scan's column handles, and intersected with the scan's
         * current constraint.
         */
        private Constraint<ColumnHandle> getConstraint(TableScanNode node, Expression predicate) {
            Map<Symbol, ColumnHandle> assignments = node.getAssignments();
            TupleDomain<Symbol> symbolDomain =
                    DomainTranslator.fromPredicate(metadata, session, predicate, types).getTupleDomain();
            return new Constraint<>(
                    symbolDomain.transform(assignments::get).intersect(node.getCurrentConstraint()),
                    bindings -> true);
        }

        @Override
        public PlanNodeCost visitValues(ValuesNode node, Void context) {
            PlanNodeCost valuesCost = PlanNodeCost.builder()
                    .setOutputRowCount(new Estimate(node.getRows().size()))
                    .build();
            costs.put(node.getId(), valuesCost);
            return valuesCost;
        }

        @Override
        public PlanNodeCost visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            visitSources(node);
            PlanNodeCost singleRowCost = PlanNodeCost.builder().setOutputRowCount(new Estimate(1.0d)).build();
            costs.put(node.getId(), singleRowCost);
            return singleRowCost;
        }

        @Override
        public PlanNodeCost visitSemiJoin(SemiJoinNode node, Void context) {
            visitSources(node);
            // Every visited node records its own cost, so the source's estimate is in the map now.
            PlanNodeCost semiJoinCost = costs.get(node.getSource().getId())
                    .mapOutputRowCount(rows -> rows * JOIN_MATCHING_COEFFICIENT);
            costs.put(node.getId(), semiJoinCost);
            return semiJoinCost;
        }

        @Override
        public PlanNodeCost visitLimit(LimitNode node, Void context) {
            Estimate sourceRows = visitSource(node).getOutputRowCount();
            // An unknown source row count falls back to the limit count, which bounds the output.
            Estimate limitRows = !sourceRows.isValueUnknown() && sourceRows.getValue() < node.getCount()
                    ? sourceRows
                    : new Estimate(node.getCount());
            PlanNodeCost limitCost = PlanNodeCost.builder().setOutputRowCount(limitRows).build();
            costs.put(node.getId(), limitCost);
            return limitCost;
        }

        /** Costs a single-source node as identical to its source. */
        private PlanNodeCost copySourceCost(PlanNode node) {
            PlanNodeCost sourceCost = visitSource(node);
            costs.put(node.getId(), sourceCost);
            return sourceCost;
        }

        /** Visits every source of {@code node}, returning their costs in source order. */
        private List<PlanNodeCost> visitSources(PlanNode node) {
            return node.getSources().stream()
                    .map(source -> source.accept(this, null))
                    .collect(Collectors.toList());
        }

        /** Visits the only source of {@code node}; fails if it does not have exactly one. */
        private PlanNodeCost visitSource(PlanNode node) {
            return Iterables.getOnlyElement(visitSources(node));
        }
    }

    @Inject
    public CoefficientBasedCostCalculator(Metadata metadata) {
        this.metadata = java.util.Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Estimates the cost of every node reachable from {@code planNode}.
     *
     * @param session session used for metadata lookups
     * @param types symbol types of the plan
     * @param planNode root of the plan to cost
     * @return immutable map from plan node id to its estimated cost
     */
    @Override
    public Map<PlanNodeId, PlanNodeCost> calculateCostForPlan(Session session, Map<Symbol, Type> types, PlanNode planNode) {
        Visitor visitor = new Visitor(session, types);
        planNode.accept(visitor, null);
        return visitor.getCosts();
    }
}
