/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.Constraint;
import com.facebook.presto.spi.statistics.TableStatistics;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import java.util.Map;
import java.util.Objects;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * A {@link StatsCalculator} that estimates only the output row count of each plan node,
 * using fixed heuristic coefficients rather than column-level statistics.
 *
 * <p>Node types without a dedicated rule below yield {@link PlanNodeStatsEstimate#UNKNOWN_STATS}.
 */
@ThreadSafe
public class CoefficientBasedStatsCalculator
implements StatsCalculator {
    /** Fraction of input rows assumed to pass a filter predicate. */
    private static final double FILTER_COEFFICIENT = 0.5;
    /** Multiplier applied to the larger join input (and to a semi-join source) to estimate output rows. */
    private static final double JOIN_MATCHING_COEFFICIENT = 2.0;
    private final Metadata metadata;

    @Inject
    public CoefficientBasedStatsCalculator(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Estimates the output row count of {@code node} from the statistics of its sources.
     *
     * <p>{@code lookup} and {@code types} are not used by this implementation.
     */
    @Override
    public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, Map<Symbol, Type> types) {
        return node.accept(new Visitor(metadata, sourceStats, session), null);
    }

    /**
     * Per-call visitor that applies one estimation rule per plan node type.
     * Static so it does not capture the enclosing calculator; the metadata is passed in explicitly.
     */
    private static final class Visitor
    extends PlanVisitor<PlanNodeStatsEstimate, Void> {
        private final Metadata metadata;
        private final StatsProvider sourceStats;
        private final Session session;

        Visitor(Metadata metadata, StatsProvider sourceStats, Session session) {
            this.metadata = metadata;
            this.sourceStats = Objects.requireNonNull(sourceStats, "sourceStats is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        /** Returns the estimate the {@link StatsProvider} holds for {@code sourceNode}. */
        private PlanNodeStatsEstimate getStats(PlanNode sourceNode) {
            return sourceStats.getStats(sourceNode);
        }

        /** Fallback for node types without a dedicated rule. */
        @Override
        protected PlanNodeStatsEstimate visitPlan(PlanNode node, Void context) {
            return PlanNodeStatsEstimate.UNKNOWN_STATS;
        }

        // NOTE(review): presumably group references are expected to be resolved before reaching
        // this calculator; confirm against the StatsProvider implementation.
        @Override
        public PlanNodeStatsEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        /** Output passes its source's estimate through unchanged. */
        @Override
        public PlanNodeStatsEstimate visitOutput(OutputNode node, Void context) {
            return getStats(node.getSource());
        }

        /** A filter keeps {@link #FILTER_COEFFICIENT} of its source rows. */
        @Override
        public PlanNodeStatsEstimate visitFilter(FilterNode node, Void context) {
            return getStats(node.getSource()).mapOutputRowCount(value -> value * FILTER_COEFFICIENT);
        }

        /** A projection passes its source's estimate through unchanged. */
        @Override
        public PlanNodeStatsEstimate visitProject(ProjectNode node, Void context) {
            return getStats(node.getSource());
        }

        /** A join yields the larger input's row count times {@link #JOIN_MATCHING_COEFFICIENT}. */
        @Override
        public PlanNodeStatsEstimate visitJoin(JoinNode node, Void context) {
            double leftRowCount = getStats(node.getLeft()).getOutputRowCount();
            double rightRowCount = getStats(node.getRight()).getOutputRowCount();
            double rowCount = Math.max(leftRowCount, rightRowCount) * JOIN_MATCHING_COEFFICIENT;
            return PlanNodeStatsEstimate.builder().setOutputRowCount(rowCount).build();
        }

        /** An exchange yields the sum of the row counts of all its sources. */
        @Override
        public PlanNodeStatsEstimate visitExchange(ExchangeNode node, Void context) {
            double rowCount = 0.0;
            for (PlanNode source : node.getSources()) {
                rowCount += getStats(source).getOutputRowCount();
            }
            return PlanNodeStatsEstimate.builder().setOutputRowCount(rowCount).build();
        }

        /** A table scan takes its row count from the connector's table statistics. */
        @Override
        public PlanNodeStatsEstimate visitTableScan(TableScanNode node, Void context) {
            // Only the scan's current constraint is passed along; the row-level predicate accepts every binding.
            Constraint<ColumnHandle> constraint = new Constraint<>(node.getCurrentConstraint(), bindings -> true);
            TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable(), constraint);
            return PlanNodeStatsEstimate.builder()
                    .setOutputRowCount(tableStatistics.getRowCount().getValue())
                    .build();
        }

        /** A VALUES node yields exactly as many rows as it lists. */
        @Override
        public PlanNodeStatsEstimate visitValues(ValuesNode node, Void context) {
            int valuesCount = node.getRows().size();
            return PlanNodeStatsEstimate.builder().setOutputRowCount(valuesCount).build();
        }

        /** EnforceSingleRow always yields one row. */
        @Override
        public PlanNodeStatsEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            return PlanNodeStatsEstimate.builder().setOutputRowCount(1.0).build();
        }

        /** A semi-join scales its source rows by {@link #JOIN_MATCHING_COEFFICIENT}. */
        // NOTE(review): a semi-join normally emits one row per source row; the coefficient may
        // over-estimate. Confirm this heuristic is intended before relying on it.
        @Override
        public PlanNodeStatsEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            return getStats(node.getSource()).mapOutputRowCount(rowCount -> rowCount * JOIN_MATCHING_COEFFICIENT);
        }

        /** A limit yields the smaller of its source row count and the limit itself. */
        @Override
        public PlanNodeStatsEstimate visitLimit(LimitNode node, Void context) {
            double sourceRowCount = getStats(node.getSource()).getOutputRowCount();
            double limit = node.getCount();
            // Deliberately not Math.min: if the source count is NaN (presumably an unknown estimate -
            // confirm), the comparison is false and the limit itself becomes the estimate.
            double rowCount = sourceRowCount < limit ? sourceRowCount : limit;
            return PlanNodeStatsEstimate.builder().setOutputRowCount(rowCount).build();
        }
    }
}

