/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.sql.DynamicFilters;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.plan.ChildReplacer;
import io.prestosql.sql.planner.plan.DynamicFilterId;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.LogicalBinaryExpression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Plan optimizer that removes dynamic-filter predicates which are not supported where they appear.
 * <p>
 * A dynamic-filter conjunct survives only when all of the following hold:
 * <ul>
 *   <li>it belongs to a {@link FilterNode} whose (rewritten) source is a {@link TableScanNode};</li>
 *   <li>its input is a plain {@link SymbolReference};</li>
 *   <li>its id is registered by an enclosing {@link JoinNode} whose probe (left) subtree contains the filter.</li>
 * </ul>
 * All other dynamic-filter predicates are removed. This covers predicates in join filters, in spatial
 * join filters, and in filters over non-scan sources. It also covers dynamic filters nested as operands
 * of AND/OR expressions, which are replaced by {@code TRUE}.
 * <p>
 * Each join's dynamic-filter map is then pruned to the ids actually consumed on its probe side.
 */
public class RemoveUnsupportedDynamicFilters implements PlanOptimizer {
    private final Metadata metadata;

    public RemoveUnsupportedDynamicFilters(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        // At the root no join encloses the plan, so no dynamic filter id is allowed yet.
        return plan.accept(new Rewriter(), ImmutableSet.of()).getNode();
    }

    /**
     * Result of rewriting a subtree.
     * <p>
     * Pairs the rewritten node with the ids of dynamic filters kept in that subtree that have not yet
     * been matched with the join registering them.
     */
    private static class PlanWithConsumedDynamicFilters {
        private final PlanNode node;
        private final Set<DynamicFilterId> consumedDynamicFilterIds;

        PlanWithConsumedDynamicFilters(PlanNode node, Set<DynamicFilterId> consumedDynamicFilterIds) {
            this.node = Objects.requireNonNull(node, "node is null");
            // Immutable snapshot: callers may hand in a mutable set.
            this.consumedDynamicFilterIds = ImmutableSet.copyOf(consumedDynamicFilterIds);
        }

        PlanNode getNode() {
            return node;
        }

        Set<DynamicFilterId> getConsumedDynamicFilterIds() {
            return consumedDynamicFilterIds;
        }
    }

    /**
     * Bottom-up rewriter.
     * <p>
     * The visitor context is the set of dynamic filter ids that enclosing joins allow to be consumed
     * in the visited subtree.
     */
    private class Rewriter extends PlanVisitor<PlanWithConsumedDynamicFilters, Set<DynamicFilterId>> {
        @Override
        protected PlanWithConsumedDynamicFilters visitPlan(PlanNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            List<PlanWithConsumedDynamicFilters> children = node.getSources().stream()
                    .map(source -> source.accept(this, allowedDynamicFilterIds))
                    .collect(ImmutableList.toImmutableList());
            List<PlanNode> rewrittenSources = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getNode)
                    .collect(Collectors.toList());
            Set<DynamicFilterId> consumedDynamicFilterIds = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getConsumedDynamicFilterIds)
                    .flatMap(Collection::stream)
                    .collect(ImmutableSet.toImmutableSet());
            return new PlanWithConsumedDynamicFilters(ChildReplacer.replaceChildren(node, rewrittenSources), consumedDynamicFilterIds);
        }

        @Override
        public PlanWithConsumedDynamicFilters visitJoin(JoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            // The probe (left) side may consume this join's own dynamic filters as well as those allowed from above.
            Set<DynamicFilterId> allowedDynamicFilterIdsProbeSide = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(node.getDynamicFilters().keySet())
                    .addAll(allowedDynamicFilterIds)
                    .build();
            PlanWithConsumedDynamicFilters leftResult = node.getLeft().accept(this, allowedDynamicFilterIdsProbeSide);
            Set<DynamicFilterId> consumedProbeSide = leftResult.getConsumedDynamicFilterIds();

            // Keep only the dynamic filters of this join that something on the probe side actually consumes.
            Map<DynamicFilterId, Symbol> dynamicFilters = node.getDynamicFilters().entrySet().stream()
                    .filter(entry -> consumedProbeSide.contains(entry.getKey()))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

            // The build (right) side only sees the ids allowed from above.
            PlanWithConsumedDynamicFilters rightResult = node.getRight().accept(this, allowedDynamicFilterIds);
            Set<DynamicFilterId> consumed = new HashSet<>(rightResult.getConsumedDynamicFilterIds());
            consumed.addAll(consumedProbeSide);
            // Ids produced by this join are matched here, so they are not reported further up.
            consumed.removeAll(dynamicFilters.keySet());

            // Dynamic filters are never kept in the join's own filter; drop the filter if only TRUE remains.
            Optional<Expression> filter = node.getFilter()
                    .map(this::removeAllDynamicFilters)
                    .filter(expression -> !expression.equals(BooleanLiteral.TRUE_LITERAL));

            PlanNode left = leftResult.getNode();
            PlanNode right = rightResult.getNode();
            boolean changed = !left.equals(node.getLeft())
                    || !right.equals(node.getRight())
                    || !dynamicFilters.equals(node.getDynamicFilters())
                    || !filter.equals(node.getFilter());
            if (changed) {
                JoinNode rewritten = new JoinNode(node.getId(), node.getType(), left, right, node.getCriteria(), node.getLeftOutputSymbols(), node.getRightOutputSymbols(), filter, node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType(), node.isSpillable(), dynamicFilters, node.getReorderJoinStatsAndCost());
                return new PlanWithConsumedDynamicFilters(rewritten, consumed);
            }
            return new PlanWithConsumedDynamicFilters(node, consumed);
        }

        @Override
        public PlanWithConsumedDynamicFilters visitSpatialJoin(SpatialJoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            // Both inputs are visited with the ids allowed from above; this node adds none of its own.
            PlanWithConsumedDynamicFilters leftResult = node.getLeft().accept(this, allowedDynamicFilterIds);
            PlanWithConsumedDynamicFilters rightResult = node.getRight().accept(this, allowedDynamicFilterIds);
            Set<DynamicFilterId> consumed = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(leftResult.getConsumedDynamicFilterIds())
                    .addAll(rightResult.getConsumedDynamicFilterIds())
                    .build();

            // Dynamic filters are never kept in the spatial join's filter.
            Expression filter = removeAllDynamicFilters(node.getFilter());
            if (!node.getFilter().equals(filter) || leftResult.getNode() != node.getLeft() || rightResult.getNode() != node.getRight()) {
                SpatialJoinNode rewritten = new SpatialJoinNode(node.getId(), node.getType(), leftResult.getNode(), rightResult.getNode(), node.getOutputSymbols(), filter, node.getLeftPartitionSymbol(), node.getRightPartitionSymbol(), node.getKdbTree());
                return new PlanWithConsumedDynamicFilters(rewritten, consumed);
            }
            return new PlanWithConsumedDynamicFilters(node, consumed);
        }

        @Override
        public PlanWithConsumedDynamicFilters visitFilter(FilterNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            PlanWithConsumedDynamicFilters result = node.getSource().accept(this, allowedDynamicFilterIds);
            ImmutableSet.Builder<DynamicFilterId> consumedDynamicFilterIds = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(result.getConsumedDynamicFilterIds());
            Expression original = node.getPredicate();
            PlanNode source = result.getNode();

            Expression modified;
            if (source instanceof TableScanNode) {
                // Only a filter directly over a table scan may keep (allowed) dynamic filters.
                modified = removeDynamicFilters(original, allowedDynamicFilterIds, consumedDynamicFilterIds);
            } else {
                modified = removeAllDynamicFilters(original);
            }

            if (BooleanLiteral.TRUE_LITERAL.equals(modified)) {
                // Nothing left to filter on: replace the FilterNode by its source.
                return new PlanWithConsumedDynamicFilters(source, consumedDynamicFilterIds.build());
            }
            if (!original.equals(modified) || source != node.getSource()) {
                return new PlanWithConsumedDynamicFilters(new FilterNode(node.getId(), source, modified), consumedDynamicFilterIds.build());
            }
            return new PlanWithConsumedDynamicFilters(node, consumedDynamicFilterIds.build());
        }

        /**
         * Keeps supported top-level dynamic-filter conjuncts and drops the rest.
         * <p>
         * A conjunct is kept when its input is a {@link SymbolReference} and its id is in
         * {@code allowedDynamicFilterIds}. Dynamic filters nested inside AND/OR operands are always
         * replaced by {@code TRUE}. Non-dynamic conjuncts are kept.
         *
         * @param expression the predicate to rewrite
         * @param allowedDynamicFilterIds ids registered by enclosing joins
         * @param consumedDynamicFilterIds receives the id of every dynamic filter that is kept
         * @return the conjunction of the retained conjuncts
         */
        private Expression removeDynamicFilters(Expression expression, Set<DynamicFilterId> allowedDynamicFilterIds, ImmutableSet.Builder<DynamicFilterId> consumedDynamicFilterIds) {
            ImmutableList.Builder<Expression> retainedConjuncts = ImmutableList.builder();
            for (Expression conjunct : ExpressionUtils.extractConjuncts(expression)) {
                Expression rewritten = removeNestedDynamicFilters(conjunct);
                boolean keep = DynamicFilters.getDescriptor(rewritten)
                        .map(descriptor -> {
                            if (!(descriptor.getInput() instanceof SymbolReference) || !allowedDynamicFilterIds.contains(descriptor.getId())) {
                                return false;
                            }
                            consumedDynamicFilterIds.add(descriptor.getId());
                            return true;
                        })
                        // Not a dynamic filter: an ordinary predicate, always kept.
                        .orElse(true);
                if (keep) {
                    retainedConjuncts.add(rewritten);
                }
            }
            return ExpressionUtils.combineConjuncts(metadata, retainedConjuncts.build());
        }

        /**
         * Removes every dynamic filter from {@code expression}, whether top-level or nested.
         * <p>
         * Returns the expression with nested dynamic filters replaced by {@code TRUE}. If top-level
         * dynamic conjuncts were present, only the static conjuncts are recombined and returned.
         */
        private Expression removeAllDynamicFilters(Expression expression) {
            Expression rewrittenExpression = removeNestedDynamicFilters(expression);
            DynamicFilters.ExtractResult extractResult = DynamicFilters.extractDynamicFilters(rewrittenExpression);
            if (extractResult.getDynamicConjuncts().isEmpty()) {
                return rewrittenExpression;
            }
            return ExpressionUtils.combineConjuncts(metadata, extractResult.getStaticConjuncts());
        }

        /**
         * Rewrites logical AND/OR trees bottom-up, replacing any operand that is a dynamic filter with
         * {@code TRUE} and recombining with the node's original operator.
         * <p>
         * Unchanged subtrees are returned as the same instance.
         */
        private Expression removeNestedDynamicFilters(Expression expression) {
            return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
                @Override
                public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                    LogicalBinaryExpression rewrittenNode = treeRewriter.defaultRewrite(node, context);
                    Expression left = rewrittenNode.getLeft();
                    Expression right = rewrittenNode.getRight();
                    boolean leftIsDynamicFilter = DynamicFilters.isDynamicFilter(left);
                    boolean rightIsDynamicFilter = DynamicFilters.isDynamicFilter(right);
                    if (node == rewrittenNode && !leftIsDynamicFilter && !rightIsDynamicFilter) {
                        return node;
                    }
                    List<Expression> operands = ImmutableList.of(
                            leftIsDynamicFilter ? BooleanLiteral.TRUE_LITERAL : left,
                            rightIsDynamicFilter ? BooleanLiteral.TRUE_LITERAL : right);
                    return ExpressionUtils.combinePredicates(metadata, node.getOperator(), operands);
                }
            }, expression);
        }
    }
}

