/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.FilterStatsCalculator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.SemiJoinStatsCalculator;
import com.facebook.presto.cost.SimpleStatsRule;
import com.facebook.presto.cost.StatsNormalizer;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.relational.ProjectNodeUtils;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Estimates statistics for a {@link FilterNode} whose source is a {@link SemiJoinNode}, either directly
 * or through a single identity {@link ProjectNode}.
 * <p>
 * The rule applies only when exactly one conjunct of the filter predicate is the semi-join output
 * variable {@code s} or its negation {@code NOT s}. That conjunct selects semi-join statistics
 * ({@code s}) or anti-join statistics ({@code NOT s}) from {@link SemiJoinStatsCalculator}. The
 * remaining conjuncts are then applied on top with the {@link FilterStatsCalculator}.
 * <p>
 * Both predicate representations are handled: predicates wrapping an original {@link Expression}
 * (as detected by {@link OriginalExpressionUtils#isExpression}) and plain {@link RowExpression}s.
 */
public class SimpleFilterProjectSemiJoinStatsRule
extends SimpleStatsRule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    /**
     * Fraction of the semi-join output row count assumed to survive the remaining predicate
     * when the filter stats calculator cannot estimate that predicate's row count.
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final FilterStatsCalculator filterStatsCalculator;
    private final LogicalRowExpressions logicalRowExpressions;
    private final FunctionResolution functionResolution;

    /**
     * @param normalizer normalizer handed to {@link SimpleStatsRule}
     * @param filterStatsCalculator used to apply the non-semi-join part of the predicate; must not be null
     * @param functionManager used to resolve the {@code NOT} function and to split/combine conjuncts; must not be null
     */
    public SimpleFilterProjectSemiJoinStatsRule(StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator, FunctionManager functionManager) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator can not be null");
        Objects.requireNonNull(functionManager, "functionManager can not be null");
        this.functionResolution = new FunctionResolution(functionManager);
        // Reuse the same resolution instance rather than constructing a second, identical one.
        this.logicalRowExpressions = new LogicalRowExpressions(new RowExpressionDeterminismEvaluator(functionManager), this.functionResolution, functionManager);
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        return findSemiJoinSource(node, lookup)
                .flatMap(semiJoinNode -> calculate(node, semiJoinNode, sourceStats, session, types));
    }

    /**
     * Returns the {@link SemiJoinNode} feeding {@code filterNode}, looking through at most one
     * identity projection. Returns empty for any other source shape, including a non-identity projection.
     */
    private static Optional<SemiJoinNode> findSemiJoinSource(FilterNode filterNode, Lookup lookup) {
        PlanNode source = lookup.resolve(filterNode.getSource());
        if (source instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) source;
            if (!ProjectNodeUtils.isIdentity(projectNode)) {
                return Optional.empty();
            }
            source = lookup.resolve(projectNode.getSource());
        }
        if (source instanceof SemiJoinNode) {
            return Optional.of((SemiJoinNode) source);
        }
        return Optional.empty();
    }

    /**
     * Computes the filter's output statistics on top of the given semi-join.
     *
     * @return empty if the predicate does not contain exactly one reference to the semi-join output;
     *         unknown stats if the semi-join row count is unknown; otherwise the estimated stats
     */
    private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, SemiJoinNode semiJoinNode, StatsProvider statsProvider, Session session, TypeProvider types) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(semiJoinNode.getSource());
        PlanNodeStatsEstimate filteringSourceStats = statsProvider.getStats(semiJoinNode.getFilteringSource());
        VariableReferenceExpression filteringSourceJoinVariable = semiJoinNode.getFilteringSourceJoinVariable();
        VariableReferenceExpression sourceJoinVariable = semiJoinNode.getSourceJoinVariable();
        VariableReferenceExpression semiJoinOutput = semiJoinNode.getSemiJoinOutput();

        RowExpression predicate = filterNode.getPredicate();
        // The predicate representation decides which overloads are used for extraction and for filtering.
        boolean originalExpression = OriginalExpressionUtils.isExpression(predicate);
        Optional<SemiJoinOutputFilter> extracted = originalExpression
                ? extractSemiJoinOutputFilter(OriginalExpressionUtils.castToExpression(predicate), semiJoinOutput)
                : extractSemiJoinOutputFilter(predicate, semiJoinOutput);
        if (!extracted.isPresent()) {
            return Optional.empty();
        }
        SemiJoinOutputFilter semiJoinOutputFilter = extracted.get();

        PlanNodeStatsEstimate semiJoinStats = semiJoinOutputFilter.isNegated()
                ? SemiJoinStatsCalculator.computeAntiJoin(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable)
                : SemiJoinStatsCalculator.computeSemiJoin(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable);
        if (semiJoinStats.isOutputRowCountUnknown()) {
            return Optional.of(PlanNodeStatsEstimate.unknown());
        }

        RowExpression remainingPredicate = semiJoinOutputFilter.getRemainingPredicate();
        PlanNodeStatsEstimate filteredStats = originalExpression
                ? filterStatsCalculator.filterStats(semiJoinStats, OriginalExpressionUtils.castToExpression(remainingPredicate), session, types)
                : filterStatsCalculator.filterStats(semiJoinStats, remainingPredicate, session, types);
        if (filteredStats.isOutputRowCountUnknown()) {
            return Optional.of(semiJoinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
        }
        return Optional.of(filteredStats);
    }

    /**
     * Splits an original-expression predicate into the single semi-join output reference and the rest.
     *
     * @return empty unless exactly one conjunct is {@code semiJoinOutput} or {@code NOT semiJoinOutput};
     *         otherwise whether that conjunct is negated, plus the remaining conjuncts (wrapped back
     *         into a {@link RowExpression})
     */
    private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression predicate, VariableReferenceExpression semiJoinOutput) {
        List<Expression> conjuncts = ExpressionUtils.extractConjuncts(predicate);
        List<Expression> semiJoinOutputReferences = conjuncts.stream()
                .filter(conjunct -> isSemiJoinOutputReference(conjunct, semiJoinOutput))
                .collect(ImmutableList.toImmutableList());
        if (semiJoinOutputReferences.size() != 1) {
            return Optional.empty();
        }
        Expression semiJoinOutputReference = Iterables.getOnlyElement(semiJoinOutputReferences);
        // Reference comparison on purpose: drop exactly the matched conjunct instance.
        List<Expression> remainingConjuncts = conjuncts.stream()
                .filter(conjunct -> conjunct != semiJoinOutputReference)
                .collect(ImmutableList.toImmutableList());
        Expression remainingPredicate = ExpressionUtils.combineConjuncts(remainingConjuncts);
        boolean negated = semiJoinOutputReference instanceof NotExpression;
        return Optional.of(new SemiJoinOutputFilter(negated, OriginalExpressionUtils.castToRowExpression(remainingPredicate)));
    }

    /**
     * Row-expression counterpart of the original-expression overload.
     *
     * @throws IllegalStateException if {@code predicate} wraps an original {@link Expression}
     */
    private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(RowExpression predicate, RowExpression input) {
        Preconditions.checkState(!OriginalExpressionUtils.isExpression(predicate), "predicate must not be an original expression");
        List<RowExpression> conjuncts = LogicalRowExpressions.extractConjuncts(predicate);
        List<RowExpression> semiJoinOutputReferences = conjuncts.stream()
                .filter(conjunct -> isSemiJoinOutputReference(conjunct, input))
                .collect(ImmutableList.toImmutableList());
        if (semiJoinOutputReferences.size() != 1) {
            return Optional.empty();
        }
        RowExpression semiJoinOutputReference = Iterables.getOnlyElement(semiJoinOutputReferences);
        // Reference comparison on purpose: drop exactly the matched conjunct instance.
        List<RowExpression> remainingConjuncts = conjuncts.stream()
                .filter(conjunct -> conjunct != semiJoinOutputReference)
                .collect(ImmutableList.toImmutableList());
        RowExpression remainingPredicate = logicalRowExpressions.combineConjuncts(remainingConjuncts);
        boolean negated = isNotFunction(semiJoinOutputReference);
        return Optional.of(new SemiJoinOutputFilter(negated, remainingPredicate));
    }

    /** True if {@code conjunct} equals {@code input} or is a {@code NOT} call whose first argument equals {@code input}. */
    private boolean isSemiJoinOutputReference(RowExpression conjunct, RowExpression input) {
        if (conjunct.equals(input)) {
            return true;
        }
        return isNotFunction(conjunct) && ((CallExpression) conjunct).getArguments().get(0).equals(input);
    }

    /** True if {@code conjunct} is a symbol reference to {@code semiJoinOutput}, or a {@link NotExpression} of one. */
    private static boolean isSemiJoinOutputReference(Expression conjunct, VariableReferenceExpression semiJoinOutput) {
        SymbolReference semiJoinOutputSymbolReference = new SymbolReference(semiJoinOutput.getName());
        if (conjunct.equals(semiJoinOutputSymbolReference)) {
            return true;
        }
        return conjunct instanceof NotExpression
                && ((NotExpression) conjunct).getValue().equals(semiJoinOutputSymbolReference);
    }

    /** True if {@code expression} is a call whose function handle resolves to {@code NOT}. */
    private boolean isNotFunction(RowExpression expression) {
        return expression instanceof CallExpression
                && functionResolution.isNotFunction(((CallExpression) expression).getFunctionHandle());
    }

    /** Result of splitting a filter predicate: whether the semi-join output was negated, plus everything else. */
    private static final class SemiJoinOutputFilter {
        private final boolean negated;
        private final RowExpression remainingPredicate;

        public SemiJoinOutputFilter(boolean negated, RowExpression remainingPredicate) {
            this.negated = negated;
            this.remainingPredicate = Objects.requireNonNull(remainingPredicate, "remainingPredicate can not be null");
        }

        public boolean isNegated() {
            return this.negated;
        }

        public RowExpression getRemainingPredicate() {
            return this.remainingPredicate;
        }
    }
}

