/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.block.SortOrder;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.iterative.GroupReference;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.optimizations.SymbolMapper;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.EnforceSingleRowNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.RowNumberNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.TopNRowNumberNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Rewrites a correlated plan subtree so that it no longer references the correlation symbols,
 * moving the correlated filter conjuncts out of the subtree.
 *
 * <p>Only a fixed set of node shapes is understood (filter, limit, TopN, enforce-single-row,
 * single-grouping-set aggregation, project and group references). Any other node that still
 * references a correlation symbol makes decorrelation fail with {@link Optional#empty()}.
 */
public class PlanNodeDecorrelator {
    private final Metadata metadata;
    private final SymbolAllocator symbolAllocator;
    private final Lookup lookup;

    public PlanNodeDecorrelator(Metadata metadata, SymbolAllocator symbolAllocator, Lookup lookup) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
    }

    /**
     * Extracts correlated filter conjuncts from {@code node}.
     *
     * @param node root of the subtree to decorrelate
     * @param correlation outer symbols that the subtree may reference
     * @return the rewritten subtree plus the extracted correlated predicates; empty when the
     *     subtree cannot be decorrelated or still references a correlation symbol afterwards.
     *     When {@code correlation} is empty, {@code node} is returned unchanged with no predicates.
     */
    public Optional<DecorrelatedNode> decorrelateFilters(PlanNode node, List<Symbol> correlation) {
        if (correlation.isEmpty()) {
            return Optional.of(new DecorrelatedNode(ImmutableList.of(), node));
        }
        Optional<DecorrelationResult> decorrelationResultOptional = node.accept(new DecorrelatingVisitor(this.metadata, correlation), null);
        return decorrelationResultOptional.flatMap(decorrelationResult -> this.decorrelatedNode(
                decorrelationResult.correlatedPredicates,
                decorrelationResult.node,
                correlation));
    }

    /** Wraps the rewritten subtree, rejecting it if it still mentions a correlation symbol. */
    private Optional<DecorrelatedNode> decorrelatedNode(List<Expression> correlatedPredicates, PlanNode node, List<Symbol> correlation) {
        if (this.containsCorrelation(node, correlation)) {
            return Optional.empty();
        }
        return Optional.of(new DecorrelatedNode(correlatedPredicates, node));
    }

    /**
     * Returns whether any symbol that {@link SymbolsExtractor} reports for {@code node}
     * (extracted symbols or output symbols) is one of the correlation symbols.
     */
    private boolean containsCorrelation(PlanNode node, List<Symbol> correlation) {
        return Sets.union(SymbolsExtractor.extractUnique(node, this.lookup), SymbolsExtractor.extractOutputSymbols(node, this.lookup))
                .stream()
                .anyMatch(correlation::contains);
    }

    /** Outcome of {@link #decorrelateFilters}: the rewritten subtree and the conjuncts removed from it. */
    public static class DecorrelatedNode {
        private final List<Expression> correlatedPredicates;
        private final PlanNode node;

        public DecorrelatedNode(List<Expression> correlatedPredicates, PlanNode node) {
            Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null");
            this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates);
            this.node = Objects.requireNonNull(node, "node is null");
        }

        /** Returns the extracted conjuncts combined with AND, or empty when none were extracted. */
        public Optional<Expression> getCorrelatedPredicates() {
            if (this.correlatedPredicates.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(ExpressionUtils.and(this.correlatedPredicates));
        }

        public PlanNode getNode() {
            return this.node;
        }
    }

    /** Intermediate state threaded bottom-up through {@link DecorrelatingVisitor}. */
    private static class DecorrelationResult {
        // rewritten (uncorrelated) subtree
        final PlanNode node;
        // subtree symbols referenced by the extracted predicates; project/aggregation rewrites
        // below keep them in the output so the predicates can still be applied above the subtree
        final Set<Symbol> symbolsToPropagate;
        // conjuncts that mention correlation symbols, removed from the subtree's filters
        final List<Expression> correlatedPredicates;
        // correlation symbol -> subtree symbol(s) it is compared equal to in a `a = b` conjunct
        final Multimap<Symbol, Symbol> correlatedSymbolsMapping;
        // set when the rewrite guarantees at most one row per correlation group; lets
        // EnforceSingleRow be dropped (see visitEnforceSingleRow)
        final boolean atMostSingleRow;

        DecorrelationResult(PlanNode node, Set<Symbol> symbolsToPropagate, List<Expression> correlatedPredicates, Multimap<Symbol, Symbol> correlatedSymbolsMapping, boolean atMostSingleRow) {
            this.node = node;
            this.symbolsToPropagate = symbolsToPropagate;
            this.correlatedPredicates = correlatedPredicates;
            this.atMostSingleRow = atMostSingleRow;
            this.correlatedSymbolsMapping = correlatedSymbolsMapping;
            Preconditions.checkState(symbolsToPropagate.containsAll(correlatedSymbolsMapping.values()), "Expected symbols to propagate to contain all constant symbols");
        }

        /** Maps each correlation symbol to the last subtree symbol it was equated with. */
        SymbolMapper getCorrelatedSymbolMapper() {
            Map<Symbol, Symbol> mapping = this.correlatedSymbolsMapping.asMap().entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> Iterables.getLast(entry.getValue())));
            return new SymbolMapper(mapping);
        }

        /**
         * Subtree symbols equated with a correlation symbol; within one correlation group
         * they therefore hold a single value.
         */
        Set<Symbol> getConstantSymbols() {
            return ImmutableSet.copyOf(this.correlatedSymbolsMapping.values());
        }
    }

    /**
     * Visitor that returns the decorrelated form of each supported node, or empty when the
     * subtree cannot be decorrelated.
     */
    private class DecorrelatingVisitor
    extends PlanVisitor<Optional<DecorrelationResult>, Void> {
        private final Metadata metadata;
        private final List<Symbol> correlation;

        DecorrelatingVisitor(Metadata metadata, List<Symbol> correlation) {
            this.metadata = metadata;
            this.correlation = Objects.requireNonNull(correlation, "correlation is null");
        }

        /** A result that keeps {@code node} as-is with nothing extracted. */
        private DecorrelationResult unchanged(PlanNode node) {
            return new DecorrelationResult(node, ImmutableSet.of(), ImmutableList.of(), ImmutableMultimap.of(), false);
        }

        @Override
        protected Optional<DecorrelationResult> visitPlan(PlanNode node, Void context) {
            // Unsupported node type: acceptable only if it does not touch correlation symbols
            if (PlanNodeDecorrelator.this.containsCorrelation(node, this.correlation)) {
                return Optional.empty();
            }
            return Optional.of(this.unchanged(node));
        }

        @Override
        public Optional<DecorrelationResult> visitGroupReference(GroupReference node, Void context) {
            return PlanNodeDecorrelator.this.lookup.resolve(node).accept(this, null);
        }

        @Override
        public Optional<DecorrelationResult> visitFilter(FilterNode node, Void context) {
            // Recurse only when the source still references correlation symbols; otherwise keep it untouched
            Optional<DecorrelationResult> childDecorrelationResultOptional = Optional.of(this.unchanged(node.getSource()));
            if (PlanNodeDecorrelator.this.containsCorrelation(node.getSource(), this.correlation)) {
                childDecorrelationResultOptional = node.getSource().accept(this, null);
            }
            if (childDecorrelationResultOptional.isEmpty()) {
                return Optional.empty();
            }
            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();

            Map<Boolean, List<Expression>> predicates = ExpressionUtils.extractConjuncts(node.getPredicate()).stream()
                    .collect(Collectors.partitioningBy(this::isCorrelated));
            List<Expression> correlatedPredicates = ImmutableList.copyOf(predicates.get(true));
            List<Expression> uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false));

            FilterNode newFilterNode = new FilterNode(
                    node.getId(),
                    childDecorrelationResult.node,
                    ExpressionUtils.combineConjuncts(this.metadata, uncorrelatedPredicates));

            // Subtree-side symbols of the extracted conjuncts must remain visible above this subtree
            Set<Symbol> symbolsToPropagate = Sets.difference(SymbolsExtractor.extractUnique(correlatedPredicates), ImmutableSet.copyOf(this.correlation));
            return Optional.of(new DecorrelationResult(
                    newFilterNode,
                    Sets.union(childDecorrelationResult.symbolsToPropagate, symbolsToPropagate),
                    ImmutableList.<Expression>builder()
                            .addAll(childDecorrelationResult.correlatedPredicates)
                            .addAll(correlatedPredicates)
                            .build(),
                    ImmutableMultimap.<Symbol, Symbol>builder()
                            .putAll(childDecorrelationResult.correlatedSymbolsMapping)
                            .putAll(this.extractCorrelatedSymbolsMapping(correlatedPredicates))
                            .build(),
                    childDecorrelationResult.atMostSingleRow));
        }

        @Override
        public Optional<DecorrelationResult> visitLimit(LimitNode node, Void context) {
            if (node.getCount() == 0L || node.isWithTies()) {
                return Optional.empty();
            }
            Optional<DecorrelationResult> childDecorrelationResultOptional = node.getSource().accept(this, null);
            if (childDecorrelationResultOptional.isEmpty()) {
                return Optional.empty();
            }
            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            if (childDecorrelationResult.atMostSingleRow) {
                // the limit cannot remove anything from at most one row per group
                return childDecorrelationResultOptional;
            }
            if (node.getCount() == 1L) {
                return this.rewriteLimitWithRowCountOne(childDecorrelationResult, node.getId());
            }
            return this.rewriteLimitWithRowCountGreaterThanOne(childDecorrelationResult, node);
        }

        /**
         * LIMIT 1: supported only when every output symbol is constant within a correlation group.
         * All rows of a group are then identical, so grouping on all outputs keeps exactly one.
         */
        private Optional<DecorrelationResult> rewriteLimitWithRowCountOne(DecorrelationResult childDecorrelationResult, PlanNodeId nodeId) {
            PlanNode decorrelatedChildNode = childDecorrelationResult.node;
            Set<Symbol> constantSymbols = childDecorrelationResult.getConstantSymbols();
            if (constantSymbols.isEmpty() || !constantSymbols.containsAll(decorrelatedChildNode.getOutputSymbols())) {
                return Optional.empty();
            }
            AggregationNode aggregationNode = new AggregationNode(
                    nodeId,
                    decorrelatedChildNode,
                    ImmutableMap.of(),
                    AggregationNode.singleGroupingSet(decorrelatedChildNode.getOutputSymbols()),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());
            return Optional.of(new DecorrelationResult(aggregationNode, childDecorrelationResult.symbolsToPropagate, childDecorrelationResult.correlatedPredicates, childDecorrelationResult.correlatedSymbolsMapping, true));
        }

        /**
         * LIMIT n (n > 1): keeps the limit when nothing needs partitioning. Otherwise it becomes a
         * row-number node partitioned by the propagated (constant) symbols, capped at n per partition.
         */
        private Optional<DecorrelationResult> rewriteLimitWithRowCountGreaterThanOne(DecorrelationResult childDecorrelationResult, LimitNode node) {
            PlanNode decorrelatedChildNode = childDecorrelationResult.node;
            if (childDecorrelationResult.symbolsToPropagate.isEmpty()) {
                return Optional.of(new DecorrelationResult(node.replaceChildren(ImmutableList.of(decorrelatedChildNode)), childDecorrelationResult.symbolsToPropagate, childDecorrelationResult.correlatedPredicates, childDecorrelationResult.correlatedSymbolsMapping, false));
            }
            Set<Symbol> constantSymbols = childDecorrelationResult.getConstantSymbols();
            if (!constantSymbols.containsAll(childDecorrelationResult.symbolsToPropagate)) {
                return Optional.empty();
            }
            if (this.exceedsRowNumberLimit(node.getCount())) {
                return Optional.empty();
            }
            RowNumberNode rowNumberNode = new RowNumberNode(
                    node.getId(),
                    decorrelatedChildNode,
                    ImmutableList.copyOf(childDecorrelationResult.symbolsToPropagate),
                    false,
                    PlanNodeDecorrelator.this.symbolAllocator.newSymbol("row_number", BigintType.BIGINT),
                    Optional.of(Math.toIntExact(node.getCount())),
                    Optional.empty());
            return Optional.of(new DecorrelationResult(rowNumberNode, childDecorrelationResult.symbolsToPropagate, childDecorrelationResult.correlatedPredicates, childDecorrelationResult.correlatedSymbolsMapping, false));
        }

        @Override
        public Optional<DecorrelationResult> visitTopN(TopNNode node, Void context) {
            if (node.getCount() == 0L) {
                return Optional.empty();
            }
            Optional<DecorrelationResult> childDecorrelationResultOptional = node.getSource().accept(this, null);
            if (childDecorrelationResultOptional.isEmpty()) {
                return Optional.empty();
            }
            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            if (childDecorrelationResult.atMostSingleRow) {
                return childDecorrelationResultOptional;
            }
            PlanNode decorrelatedChildNode = childDecorrelationResult.node;
            Set<Symbol> constantSymbols = childDecorrelationResult.getConstantSymbols();
            // sort keys constant within a correlation group do not affect the order and are dropped
            Optional<OrderingScheme> decorrelatedOrderingScheme = this.decorrelateOrderingScheme(node.getOrderingScheme(), constantSymbols);
            boolean atMostSingleRow = node.getCount() == 1L;

            if (childDecorrelationResult.symbolsToPropagate.isEmpty()) {
                // No partitioning needed: keep a TopN if sort keys remain, otherwise a plain limit
                PlanNode rewritten = decorrelatedOrderingScheme
                        .<PlanNode>map(orderingScheme -> new TopNNode(node.getId(), decorrelatedChildNode, node.getCount(), orderingScheme, node.getStep()))
                        .orElseGet(() -> new LimitNode(node.getId(), decorrelatedChildNode, node.getCount(), false));
                return Optional.of(new DecorrelationResult(rewritten, childDecorrelationResult.symbolsToPropagate, childDecorrelationResult.correlatedPredicates, childDecorrelationResult.correlatedSymbolsMapping, atMostSingleRow));
            }
            if (!constantSymbols.containsAll(childDecorrelationResult.symbolsToPropagate)) {
                return Optional.empty();
            }
            if (this.exceedsRowNumberLimit(node.getCount())) {
                return Optional.empty();
            }

            // Per-correlation-group TopN: partition by the propagated (constant) symbols
            int maxRowCountPerPartition = Math.toIntExact(node.getCount());
            List<Symbol> partitionBy = ImmutableList.copyOf(childDecorrelationResult.symbolsToPropagate);
            Symbol rowNumberSymbol = PlanNodeDecorrelator.this.symbolAllocator.newSymbol("row_number", BigintType.BIGINT);
            PlanNode rewritten = decorrelatedOrderingScheme
                    .<PlanNode>map(orderingScheme -> new TopNRowNumberNode(
                            node.getId(),
                            decorrelatedChildNode,
                            new WindowNode.Specification(partitionBy, Optional.of(orderingScheme)),
                            rowNumberSymbol,
                            maxRowCountPerPartition,
                            false,
                            Optional.empty()))
                    .orElseGet(() -> new RowNumberNode(
                            node.getId(),
                            decorrelatedChildNode,
                            partitionBy,
                            false,
                            rowNumberSymbol,
                            Optional.of(maxRowCountPerPartition),
                            Optional.empty()));
            return Optional.of(new DecorrelationResult(rewritten, childDecorrelationResult.symbolsToPropagate, childDecorrelationResult.correlatedPredicates, childDecorrelationResult.correlatedSymbolsMapping, atMostSingleRow));
        }

        /**
         * Row-number based nodes take their per-partition cap as an {@code int}. A larger count
         * cannot be expressed there, so decorrelation is declined instead of letting
         * {@link Math#toIntExact} throw {@link ArithmeticException} during planning.
         */
        private boolean exceedsRowNumberLimit(long count) {
            return count > Integer.MAX_VALUE;
        }

        /**
         * Drops sort keys that are correlation symbols or constant symbols.
         *
         * @return the remaining ordering, or empty when no sort key is left
         */
        private Optional<OrderingScheme> decorrelateOrderingScheme(OrderingScheme orderingScheme, Set<Symbol> constantSymbols) {
            ImmutableList.Builder<Symbol> nonConstantOrderBy = ImmutableList.builder();
            ImmutableMap.Builder<Symbol, SortOrder> nonConstantOrderings = ImmutableMap.builder();
            for (Symbol symbol : orderingScheme.getOrderBy()) {
                if (constantSymbols.contains(symbol) || this.correlation.contains(symbol)) {
                    continue;
                }
                nonConstantOrderBy.add(symbol);
                nonConstantOrderings.put(symbol, orderingScheme.getOrdering(symbol));
            }
            List<Symbol> orderBy = nonConstantOrderBy.build();
            if (orderBy.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(new OrderingScheme(orderBy, nonConstantOrderings.build()));
        }

        @Override
        public Optional<DecorrelationResult> visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            // The node is removed; this is only valid if the rewrite already guarantees a single row
            return node.getSource().accept(this, null).filter(result -> result.atMostSingleRow);
        }

        @Override
        public Optional<DecorrelationResult> visitAggregation(AggregationNode node, Void context) {
            if (node.hasEmptyGroupingSet()) {
                return Optional.empty();
            }
            if (node.getGroupingSetCount() != 1) {
                return Optional.empty();
            }
            Optional<DecorrelationResult> childDecorrelationResultOptional = node.getSource().accept(this, null);
            if (childDecorrelationResultOptional.isEmpty()) {
                return Optional.empty();
            }
            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            Set<Symbol> constantSymbols = childDecorrelationResult.getConstantSymbols();
            AggregationNode decorrelatedAggregation = childDecorrelationResult.getCorrelatedSymbolMapper().map(node, childDecorrelationResult.node);

            // Propagated symbols become extra grouping keys; safe only if they are constant per group
            Set<Symbol> groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys());
            List<Symbol> symbolsToAdd = childDecorrelationResult.symbolsToPropagate.stream()
                    .filter(symbol -> !groupingKeys.contains(symbol))
                    .collect(ImmutableList.toImmutableList());
            if (!constantSymbols.containsAll(symbolsToAdd)) {
                return Optional.empty();
            }
            AggregationNode newAggregation = new AggregationNode(
                    decorrelatedAggregation.getId(),
                    decorrelatedAggregation.getSource(),
                    decorrelatedAggregation.getAggregations(),
                    AggregationNode.singleGroupingSet(ImmutableList.<Symbol>builder()
                            .addAll(node.getGroupingKeys())
                            .addAll(symbolsToAdd)
                            .build()),
                    ImmutableList.of(),
                    decorrelatedAggregation.getStep(),
                    decorrelatedAggregation.getHashSymbol(),
                    decorrelatedAggregation.getGroupIdSymbol());
            // grouping only on constant keys yields one row per correlation group
            boolean atMostSingleRow = constantSymbols.containsAll(newAggregation.getGroupingKeys());
            return Optional.of(new DecorrelationResult(newAggregation, childDecorrelationResult.symbolsToPropagate, childDecorrelationResult.correlatedPredicates, childDecorrelationResult.correlatedSymbolsMapping, atMostSingleRow));
        }

        @Override
        public Optional<DecorrelationResult> visitProject(ProjectNode node, Void context) {
            Optional<DecorrelationResult> childDecorrelationResultOptional = node.getSource().accept(this, null);
            if (childDecorrelationResultOptional.isEmpty()) {
                return Optional.empty();
            }
            DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get();
            // Pass propagated symbols through with identity assignments if the projection dropped them
            Set<Symbol> nodeOutputSymbols = ImmutableSet.copyOf(node.getOutputSymbols());
            List<Symbol> symbolsToAdd = childDecorrelationResult.symbolsToPropagate.stream()
                    .filter(symbol -> !nodeOutputSymbols.contains(symbol))
                    .collect(ImmutableList.toImmutableList());
            Assignments assignments = Assignments.builder()
                    .putAll(node.getAssignments())
                    .putIdentities(symbolsToAdd)
                    .build();
            return Optional.of(new DecorrelationResult(new ProjectNode(node.getId(), childDecorrelationResult.node, assignments), childDecorrelationResult.symbolsToPropagate, childDecorrelationResult.correlatedPredicates, childDecorrelationResult.correlatedSymbolsMapping, childDecorrelationResult.atMostSingleRow));
        }

        /**
         * Collects {@code correlationSymbol = subtreeSymbol} equalities (in either operand order)
         * as a mapping from the correlation symbol to the subtree symbol. Other conjunct shapes are ignored.
         */
        private Multimap<Symbol, Symbol> extractCorrelatedSymbolsMapping(List<Expression> correlatedConjuncts) {
            ImmutableMultimap.Builder<Symbol, Symbol> mapping = ImmutableMultimap.builder();
            for (Expression conjunct : correlatedConjuncts) {
                if (!(conjunct instanceof ComparisonExpression)) {
                    continue;
                }
                ComparisonExpression comparison = (ComparisonExpression) conjunct;
                if (!(comparison.getLeft() instanceof SymbolReference)
                        || !(comparison.getRight() instanceof SymbolReference)
                        || comparison.getOperator() != ComparisonExpression.Operator.EQUAL) {
                    continue;
                }
                Symbol left = Symbol.from(comparison.getLeft());
                Symbol right = Symbol.from(comparison.getRight());
                boolean leftIsCorrelation = this.correlation.contains(left);
                boolean rightIsCorrelation = this.correlation.contains(right);
                if (leftIsCorrelation && !rightIsCorrelation) {
                    mapping.put(left, right);
                } else if (rightIsCorrelation && !leftIsCorrelation) {
                    mapping.put(right, left);
                }
            }
            return mapping.build();
        }

        /** Returns whether {@code expression} references any correlation symbol. */
        private boolean isCorrelated(Expression expression) {
            Set<Symbol> referenced = SymbolsExtractor.extractUnique(expression);
            return this.correlation.stream().anyMatch(referenced::contains);
        }
    }
}

