/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.optimizations.SymbolMapper;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Tries to separate the correlated predicates of a subquery plan from the rest of the plan.
 *
 * <p>Given a plan and the list of correlation symbols, the decorrelator walks the plan with a
 * {@link DecorrelatingVisitor}, removes every filter conjunct that references a correlation
 * symbol, and returns those conjuncts together with the rewritten plan. The attempt fails
 * (empty result) when a visited node cannot be rewritten or when the rewritten plan still
 * references a correlation symbol.
 */
public class PlanNodeDecorrelator {
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;

    /**
     * @param idAllocator source of ids for the plan nodes created during the rewrite
     * @param lookup used to resolve plan nodes before they are visited or inspected
     * @throws NullPointerException if either argument is null
     */
    public PlanNodeDecorrelator(PlanNodeIdAllocator idAllocator, Lookup lookup) {
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
    }

    /**
     * Attempts to decorrelate {@code node} with respect to {@code correlation}.
     *
     * @param node root of the plan to rewrite
     * @param correlation the correlation symbols
     * @return the rewritten plan plus the extracted correlated predicates, or empty when the
     *         plan could not be rewritten or the rewritten plan still uses a correlation symbol
     */
    public Optional<DecorrelatedNode> decorrelateFilters(PlanNode node, List<Symbol> correlation) {
        Optional<DecorrelationResult> decorrelationResultOptional = lookup.resolve(node).accept(new DecorrelatingVisitor(correlation), null);
        return decorrelationResultOptional.flatMap(decorrelationResult -> decorrelatedNode(
                decorrelationResult.correlatedPredicates,
                decorrelationResult.node,
                correlation));
    }

    /**
     * Wraps a rewritten plan into a {@link DecorrelatedNode}, rejecting it when it still
     * contains one of the correlation symbols.
     */
    private Optional<DecorrelatedNode> decorrelatedNode(List<Expression> correlatedPredicates, PlanNode node, List<Symbol> correlation) {
        if (containsCorrelation(node, correlation)) {
            return Optional.empty();
        }
        return Optional.of(new DecorrelatedNode(correlatedPredicates, node));
    }

    /**
     * Returns true when any symbol reported by {@code SymbolsExtractor.extractUnique} or
     * {@code SymbolsExtractor.extractOutputSymbols} for {@code node} is a correlation symbol.
     */
    private boolean containsCorrelation(PlanNode node, List<Symbol> correlation) {
        return Sets.union(SymbolsExtractor.extractUnique(node, lookup), SymbolsExtractor.extractOutputSymbols(node, lookup)).stream()
                .anyMatch(correlation::contains);
    }

    /**
     * Successful outcome of a decorrelation: the rewritten plan and the correlated predicates
     * that were removed from it.
     */
    public static class DecorrelatedNode {
        private final List<Expression> correlatedPredicates;
        private final PlanNode node;

        /**
         * @param correlatedPredicates predicates removed from the plan; copied defensively
         * @param node the rewritten plan
         * @throws NullPointerException if either argument is null
         */
        public DecorrelatedNode(List<Expression> correlatedPredicates, PlanNode node) {
            Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null");
            this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates);
            this.node = Objects.requireNonNull(node, "node is null");
        }

        /**
         * @return the conjunction of all extracted correlated predicates, or empty when no
         *         predicate was extracted
         */
        public Optional<Expression> getCorrelatedPredicates() {
            if (correlatedPredicates.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(ExpressionUtils.and(correlatedPredicates));
        }

        /**
         * @return the rewritten plan
         */
        public PlanNode getNode() {
            return node;
        }
    }

    /**
     * Intermediate state passed up the plan while visiting.
     */
    private static class DecorrelationResult {
        /** The rewritten node. */
        final PlanNode node;
        /** Non-correlation symbols used by the extracted predicates; parents must keep them in their output. */
        final Set<Symbol> symbolsToPropagate;
        /** Conjuncts removed from filters because they reference a correlation symbol. */
        final List<Expression> correlatedPredicates;
        /** For each correlation symbol, the non-correlation symbols it is compared to with a symbol-to-symbol equality. */
        final Multimap<Symbol, Symbol> correlatedSymbolsMapping;
        /** Flag recorded by the visitor; set to true by visitLimit and computed by visitAggregation. */
        final boolean atMostSingleRow;

        /**
         * @throws IllegalStateException if {@code symbolsToPropagate} does not contain every
         *         value of {@code correlatedSymbolsMapping}
         */
        DecorrelationResult(PlanNode node, Set<Symbol> symbolsToPropagate, List<Expression> correlatedPredicates, Multimap<Symbol, Symbol> correlatedSymbolsMapping, boolean atMostSingleRow) {
            this.node = node;
            this.symbolsToPropagate = symbolsToPropagate;
            this.correlatedPredicates = correlatedPredicates;
            this.atMostSingleRow = atMostSingleRow;
            this.correlatedSymbolsMapping = correlatedSymbolsMapping;
            Preconditions.checkState(symbolsToPropagate.containsAll(correlatedSymbolsMapping.values()), "Expected symbols to propagate to contain all constant symbols");
        }

        /**
         * Creates a result for a node from which nothing was extracted: no propagated symbols,
         * no correlated predicates, no symbol mapping, and {@code atMostSingleRow} false.
         */
        static DecorrelationResult uncorrelated(PlanNode node) {
            return new DecorrelationResult(node, ImmutableSet.of(), ImmutableList.of(), ImmutableMultimap.of(), false);
        }

        /**
         * Builds a mapper from each correlation symbol to the last symbol recorded for it in
         * {@link #correlatedSymbolsMapping}.
         */
        SymbolMapper getCorrelatedSymbolMapper() {
            return new SymbolMapper(correlatedSymbolsMapping.asMap().entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, symbols -> Iterables.getLast(symbols.getValue()))));
        }

        /**
         * @return the symbols that are equated to a correlation symbol by an extracted predicate
         */
        Set<Symbol> getConstantSymbols() {
            return ImmutableSet.copyOf(correlatedSymbolsMapping.values());
        }
    }

    /**
     * Visitor performing the rewrite. Each method returns empty when the visited subtree
     * cannot be decorrelated.
     */
    private class DecorrelatingVisitor
            extends InternalPlanVisitor<Optional<DecorrelationResult>, Void> {
        final List<Symbol> correlation;

        DecorrelatingVisitor(List<Symbol> correlation) {
            this.correlation = Objects.requireNonNull(correlation, "correlation is null");
        }

        /**
         * Default case: the node is returned unchanged and is not descended into.
         */
        @Override
        protected Optional<DecorrelationResult> visitPlan(PlanNode node, Void context) {
            return Optional.of(DecorrelationResult.uncorrelated(node));
        }

        /**
         * Splits the filter predicate into correlated and uncorrelated conjuncts. The
         * uncorrelated ones stay in a new filter; the correlated ones are added to the result.
         */
        @Override
        public Optional<DecorrelationResult> visitFilter(FilterNode node, Void context) {
            PlanNode source = node.getSource();

            // Only descend when the subtree below the filter uses a correlation symbol;
            // otherwise the source is kept as it is.
            Optional<DecorrelationResult> childDecorrelationResultOptional;
            if (containsCorrelation(source, correlation)) {
                childDecorrelationResultOptional = lookup.resolve(source).accept(this, null);
            }
            else {
                childDecorrelationResultOptional = Optional.of(DecorrelationResult.uncorrelated(source));
            }
            if (!childDecorrelationResultOptional.isPresent()) {
                return Optional.empty();
            }

            Expression predicate = OriginalExpressionUtils.castToExpression(node.getPredicate());
            Map<Boolean, List<Expression>> predicates = ExpressionUtils.extractConjuncts(predicate).stream()
                    .collect(Collectors.partitioningBy(this::isCorrelated));
            List<Expression> correlatedPredicates = ImmutableList.copyOf(predicates.get(true));
            List<Expression> uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false));

            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            FilterNode newFilterNode = new FilterNode(
                    idAllocator.getNextId(),
                    childDecorrelationResult.node,
                    OriginalExpressionUtils.castToRowExpression(ExpressionUtils.combineConjuncts(uncorrelatedPredicates)));

            // Symbols used by the extracted predicates, minus the correlation symbols themselves.
            Set<Symbol> symbolsToPropagate = Sets.difference(SymbolsExtractor.extractUnique(correlatedPredicates), ImmutableSet.copyOf(correlation));
            return Optional.of(new DecorrelationResult(
                    newFilterNode,
                    Sets.union(childDecorrelationResult.symbolsToPropagate, symbolsToPropagate),
                    ImmutableList.<Expression>builder()
                            .addAll(childDecorrelationResult.correlatedPredicates)
                            .addAll(correlatedPredicates)
                            .build(),
                    ImmutableMultimap.<Symbol, Symbol>builder()
                            .putAll(childDecorrelationResult.correlatedSymbolsMapping)
                            .putAll(extractCorrelatedSymbolsMapping(correlatedPredicates))
                            .build(),
                    childDecorrelationResult.atMostSingleRow));
        }

        /**
         * Handles a limit above a decorrelated subtree. The limit node is dropped when the child
         * is already flagged at-most-single-row. A limit of exactly 1 over a child whose outputs
         * are all constant symbols is replaced by an aggregation grouping on those outputs.
         * Every other case (including a limit of 0) fails.
         */
        @Override
        public Optional<DecorrelationResult> visitLimit(LimitNode node, Void context) {
            Optional<DecorrelationResult> childDecorrelationResultOptional = lookup.resolve(node.getSource()).accept(this, null);
            if (!childDecorrelationResultOptional.isPresent() || node.getCount() == 0) {
                return Optional.empty();
            }

            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            if (childDecorrelationResult.atMostSingleRow) {
                return childDecorrelationResultOptional;
            }

            if (node.getCount() != 1) {
                return Optional.empty();
            }

            Set<Symbol> constantSymbols = childDecorrelationResult.getConstantSymbols();
            PlanNode decorrelatedChildNode = childDecorrelationResult.node;
            if (constantSymbols.isEmpty() || !constantSymbols.containsAll(decorrelatedChildNode.getOutputSymbols())) {
                return Optional.empty();
            }

            // Replace the limit with an aggregation (no aggregate functions) that groups on the
            // child's output symbols, all of which are constant symbols at this point.
            AggregationNode aggregationNode = new AggregationNode(
                    idAllocator.getNextId(),
                    decorrelatedChildNode,
                    ImmutableMap.of(),
                    AggregationNode.singleGroupingSet(decorrelatedChildNode.getOutputSymbols()),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());

            return Optional.of(new DecorrelationResult(
                    aggregationNode,
                    childDecorrelationResult.symbolsToPropagate,
                    childDecorrelationResult.correlatedPredicates,
                    childDecorrelationResult.correlatedSymbolsMapping,
                    true));
        }

        /**
         * Drops the enforce-single-row node when its child is flagged at-most-single-row;
         * fails otherwise.
         */
        @Override
        public Optional<DecorrelationResult> visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            Optional<DecorrelationResult> childDecorrelationResultOptional = lookup.resolve(node.getSource()).accept(this, null);
            return childDecorrelationResultOptional.filter(result -> result.atMostSingleRow);
        }

        /**
         * Rewrites an aggregation on top of a decorrelated child, adding the propagated symbols
         * to the grouping keys. Fails for aggregations with an empty grouping set and when a
         * symbol that has to be added is not a constant symbol.
         */
        @Override
        public Optional<DecorrelationResult> visitAggregation(AggregationNode node, Void context) {
            if (node.hasEmptyGroupingSet()) {
                return Optional.empty();
            }

            Optional<DecorrelationResult> childDecorrelationResultOptional = lookup.resolve(node.getSource()).accept(this, null);
            if (!childDecorrelationResultOptional.isPresent()) {
                return Optional.empty();
            }

            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            Set<Symbol> constantSymbols = childDecorrelationResult.getConstantSymbols();

            AggregationNode decorrelatedAggregation = childDecorrelationResult.getCorrelatedSymbolMapper()
                    .map(node, childDecorrelationResult.node);

            Set<Symbol> groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys());
            List<Symbol> symbolsToAdd = childDecorrelationResult.symbolsToPropagate.stream()
                    .filter(symbol -> !groupingKeys.contains(symbol))
                    .collect(ImmutableList.toImmutableList());

            // Only constant symbols may be appended to the grouping keys.
            if (!constantSymbols.containsAll(symbolsToAdd)) {
                return Optional.empty();
            }

            AggregationNode newAggregation = new AggregationNode(
                    decorrelatedAggregation.getId(),
                    decorrelatedAggregation.getSource(),
                    decorrelatedAggregation.getAggregations(),
                    AggregationNode.singleGroupingSet(ImmutableList.<Symbol>builder()
                            .addAll(node.getGroupingKeys())
                            .addAll(symbolsToAdd)
                            .build()),
                    ImmutableList.of(),
                    decorrelatedAggregation.getStep(),
                    decorrelatedAggregation.getHashSymbol(),
                    decorrelatedAggregation.getGroupIdSymbol());

            boolean atMostSingleRow = newAggregation.getGroupingSetCount() == 1
                    && constantSymbols.containsAll(newAggregation.getGroupingKeys());

            return Optional.of(new DecorrelationResult(
                    newAggregation,
                    childDecorrelationResult.symbolsToPropagate,
                    childDecorrelationResult.correlatedPredicates,
                    childDecorrelationResult.correlatedSymbolsMapping,
                    atMostSingleRow));
        }

        /**
         * Rebuilds the projection on top of the decorrelated child, adding identity assignments
         * for propagated symbols the projection does not already output.
         */
        @Override
        public Optional<DecorrelationResult> visitProject(ProjectNode node, Void context) {
            Optional<DecorrelationResult> childDecorrelationResultOptional = lookup.resolve(node.getSource()).accept(this, null);
            if (!childDecorrelationResultOptional.isPresent()) {
                return Optional.empty();
            }

            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            Set<Symbol> nodeOutputSymbols = ImmutableSet.copyOf(node.getOutputSymbols());
            List<Symbol> symbolsToAdd = childDecorrelationResult.symbolsToPropagate.stream()
                    .filter(symbol -> !nodeOutputSymbols.contains(symbol))
                    .collect(ImmutableList.toImmutableList());

            Assignments assignments = Assignments.builder()
                    .putAll(node.getAssignments())
                    .putIdentities(symbolsToAdd)
                    .build();

            return Optional.of(new DecorrelationResult(
                    new ProjectNode(idAllocator.getNextId(), childDecorrelationResult.node, assignments),
                    childDecorrelationResult.symbolsToPropagate,
                    childDecorrelationResult.correlatedPredicates,
                    childDecorrelationResult.correlatedSymbolsMapping,
                    childDecorrelationResult.atMostSingleRow));
        }

        /**
         * Collects, from conjuncts of the form {@code symbol = symbol}, the pairs in which
         * exactly one side is a correlation symbol. The correlation symbol is the key and the
         * other symbol is the value. All other conjuncts are ignored.
         */
        private Multimap<Symbol, Symbol> extractCorrelatedSymbolsMapping(List<Expression> correlatedConjuncts) {
            ImmutableMultimap.Builder<Symbol, Symbol> mapping = ImmutableMultimap.builder();
            for (Expression conjunct : correlatedConjuncts) {
                if (!(conjunct instanceof ComparisonExpression)) {
                    continue;
                }

                ComparisonExpression comparison = (ComparisonExpression) conjunct;
                boolean symbolEquality = comparison.getLeft() instanceof SymbolReference
                        && comparison.getRight() instanceof SymbolReference
                        && comparison.getOperator().equals(ComparisonExpression.Operator.EQUAL);
                if (!symbolEquality) {
                    continue;
                }

                Symbol left = Symbol.from(comparison.getLeft());
                Symbol right = Symbol.from(comparison.getRight());

                if (correlation.contains(left) && !correlation.contains(right)) {
                    mapping.put(left, right);
                }

                if (correlation.contains(right) && !correlation.contains(left)) {
                    mapping.put(right, left);
                }
            }
            return mapping.build();
        }

        /**
         * Returns true when {@code expression} references at least one correlation symbol.
         */
        private boolean isCorrelated(Expression expression) {
            return correlation.stream().anyMatch(SymbolsExtractor.extractUnique(expression)::contains);
        }
    }
}

