/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.util.MorePredicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Splits the predicate of a single {@link FilterNode} into conjuncts that reference correlation
 * symbols and conjuncts that do not. The correlated conjuncts are returned separately, and the
 * rewritten plan no longer references any correlation symbol.
 *
 * <p>NOTE(review): the correlated conjuncts are presumably used by the caller as a join
 * condition (see {@code ensureJoinSymbolsAreReturned}); confirm against callers.
 */
public class PlanNodeDecorrelator {
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;

    /**
     * @param idAllocator allocator for ids of the plan nodes created during rewriting; must not be null
     * @param lookup lookup used when searching the plan and extracting symbols; must not be null
     */
    public PlanNodeDecorrelator(PlanNodeIdAllocator idAllocator, Lookup lookup) {
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
    }

    /**
     * Attempts to pull correlated conjuncts out of the filter found under {@code node}.
     *
     * <p>The search starts at {@code node}, matches {@link FilterNode}s and descends only through
     * {@link ProjectNode}s and {@link LimitNode}s. Decorrelation is abandoned (empty result) when:
     * <ul>
     *   <li>more than one filter is found;</li>
     *   <li>the filter predicate contains a non-AND logical binary expression (see
     *       {@link #isSupportedPredicate});</li>
     *   <li>the predicate does not reference every symbol in {@code correlation};</li>
     *   <li>the rewritten plan still references a correlation symbol.</li>
     * </ul>
     *
     * @param node root of the subplan to decorrelate
     * @param correlation correlation symbols that must be removed from the subplan
     * @return the rewritten subplan plus its correlated conjuncts, or empty if it cannot be decorrelated
     */
    public Optional<DecorrelatedNode> decorrelateFilters(PlanNode node, List<Symbol> correlation) {
        PlanNodeSearcher filterNodeSearcher = PlanNodeSearcher.searchFrom(node, this.lookup)
                .where(FilterNode.class::isInstance)
                .recurseOnlyWhen(MorePredicates.isInstanceOfAny(ProjectNode.class, LimitNode.class));
        List<PlanNode> filterNodes = filterNodeSearcher.findAll();
        if (filterNodes.isEmpty()) {
            // Nothing to pull out; succeed only if the plan is already free of correlation symbols.
            return this.decorrelatedNode(ImmutableList.of(), node, correlation);
        }
        if (filterNodes.size() > 1) {
            return Optional.empty();
        }
        FilterNode filterNode = (FilterNode) filterNodes.get(0);
        Expression predicate = filterNode.getPredicate();
        if (!isSupportedPredicate(predicate)) {
            return Optional.empty();
        }
        if (!SymbolsExtractor.extractUnique(predicate).containsAll(correlation)) {
            return Optional.empty();
        }

        Map<Boolean, List<Expression>> predicates = ExpressionUtils.extractConjuncts(predicate).stream()
                .collect(Collectors.partitioningBy(this.isUsingPredicate(correlation)));
        List<Expression> correlatedPredicates = ImmutableList.copyOf(predicates.get(true));
        List<Expression> uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false));

        PlanNode rewritten = this.updateFilterNode(filterNodeSearcher, uncorrelatedPredicates);
        if (!correlatedPredicates.isEmpty()) {
            // NOTE(review): presumably a LIMIT below the extracted correlated conjuncts could not be
            // applied per correlated row once those conjuncts move out of the subplan; confirm.
            rewritten = this.removeLimitNode(rewritten);
        }
        rewritten = this.ensureJoinSymbolsAreReturned(rewritten, correlatedPredicates);
        return this.decorrelatedNode(correlatedPredicates, rewritten, correlation);
    }

    /**
     * Returns {@code false} if the traversal encounters a {@link LogicalBinaryExpression} whose type
     * is not AND.
     *
     * <p>The override below returns without delegating to the default traversal. As a result, the
     * operands of any logical binary expression are not visited. An OR nested inside a top-level
     * AND is therefore not rejected here; it simply becomes one conjunct.
     */
    private static boolean isSupportedPredicate(Expression predicate) {
        AtomicBoolean isSupported = new AtomicBoolean(true);
        new DefaultTraversalVisitor<Void, AtomicBoolean>() {
            @Override
            protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, AtomicBoolean context) {
                if (node.getType() != LogicalBinaryExpression.Type.AND) {
                    context.set(false);
                }
                return null;
            }
        }.process(predicate, isSupported);
        return isSupported.get();
    }

    /**
     * Returns a predicate that is true for an expression referencing at least one of {@code symbols}.
     */
    private Predicate<Expression> isUsingPredicate(List<Symbol> symbols) {
        return expression -> symbols.stream().anyMatch(SymbolsExtractor.extractUnique(expression)::contains);
    }

    /**
     * Replaces the filter found by {@code filterNodeSearcher} with one over the same source that
     * evaluates {@code newPredicates}, combined via {@code ExpressionUtils.combineConjuncts}. If no
     * predicates remain, the matched filter is removed instead.
     */
    private PlanNode updateFilterNode(PlanNodeSearcher filterNodeSearcher, List<Expression> newPredicates) {
        if (newPredicates.isEmpty()) {
            return filterNodeSearcher.removeAll();
        }
        FilterNode oldFilterNode = (FilterNode) Iterables.getOnlyElement(filterNodeSearcher.findAll());
        FilterNode newFilterNode = new FilterNode(
                this.idAllocator.getNextId(),
                oldFilterNode.getSource(),
                ExpressionUtils.combineConjuncts(newPredicates));
        return filterNodeSearcher.replaceAll(newFilterNode);
    }

    /**
     * Removes the first {@link LimitNode} found from {@code node}, descending only through
     * {@link ProjectNode}s.
     */
    private PlanNode removeLimitNode(PlanNode node) {
        return PlanNodeSearcher.searchFrom(node, this.lookup)
                .where(LimitNode.class::isInstance)
                .recurseOnlyWhen(ProjectNode.class::isInstance)
                .removeFirst();
    }

    /**
     * Rewrites every {@link ProjectNode} in {@code source} so that it also outputs, as identity
     * assignments, any symbol referenced by {@code joinPredicate} that is available from the
     * project's source.
     */
    private PlanNode ensureJoinSymbolsAreReturned(PlanNode source, List<Expression> joinPredicate) {
        Set<Symbol> joinExpressionSymbols = SymbolsExtractor.extractUnique(joinPredicate);
        ExtendProjectionRewriter extendProjectionRewriter = new ExtendProjectionRewriter(this.idAllocator, joinExpressionSymbols);
        return SimplePlanRewriter.rewriteWith(extendProjectionRewriter, source);
    }

    /**
     * Wraps {@code node} and {@code correlatedPredicates} into a result. Returns empty if
     * {@code node} still references any correlation symbol.
     */
    private Optional<DecorrelatedNode> decorrelatedNode(List<Expression> correlatedPredicates, PlanNode node, List<Symbol> correlation) {
        if (SymbolsExtractor.extractUnique(node, this.lookup).stream().anyMatch(correlation::contains)) {
            return Optional.empty();
        }
        return Optional.of(new DecorrelatedNode(correlatedPredicates, node));
    }

    /**
     * Plan rewriter that extends each {@link ProjectNode} with identity assignments for the
     * requested symbols that its source produces but it does not yet output.
     */
    private static class ExtendProjectionRewriter
    extends SimplePlanRewriter<PlanNode> {
        private final PlanNodeIdAllocator idAllocator;
        private final Set<Symbol> symbols;

        ExtendProjectionRewriter(PlanNodeIdAllocator idAllocator, Set<Symbol> symbols) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbols = Objects.requireNonNull(symbols, "symbols is null");
        }

        /**
         * Rewrites the children first, then rebuilds this project with a fresh id. The rebuilt
         * project keeps the existing assignments and adds identity assignments for the missing
         * requested symbols. A new node is created even when nothing needs to be added.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node, context.get());
            List<Symbol> sourceOutputs = rewrittenNode.getSource().getOutputSymbols();
            List<Symbol> projectOutputs = rewrittenNode.getOutputSymbols();
            List<Symbol> symbolsToAdd = this.symbols.stream()
                    .filter(sourceOutputs::contains)
                    .filter(symbol -> !projectOutputs.contains(symbol))
                    .collect(ImmutableList.toImmutableList());
            Assignments assignments = Assignments.builder()
                    .putAll(rewrittenNode.getAssignments())
                    .putIdentities(symbolsToAdd)
                    .build();
            return new ProjectNode(this.idAllocator.getNextId(), rewrittenNode.getSource(), assignments);
        }
    }

    /**
     * Result of a successful decorrelation: the rewritten subplan plus the correlated conjuncts
     * that were removed from it.
     */
    public static class DecorrelatedNode {
        private final List<Expression> correlatedPredicates;
        private final PlanNode node;

        /**
         * @param correlatedPredicates conjuncts removed from the plan; defensively copied, must not be null
         * @param node the rewritten subplan; must not be null
         */
        public DecorrelatedNode(List<Expression> correlatedPredicates, PlanNode node) {
            Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null");
            this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates);
            this.node = Objects.requireNonNull(node, "node is null");
        }

        /**
         * Returns the correlated conjuncts combined with {@code ExpressionUtils.and}, or empty if
         * there are none.
         */
        Optional<Expression> getCorrelatedPredicates() {
            if (this.correlatedPredicates.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(ExpressionUtils.and(this.correlatedPredicates));
        }

        /** Returns the rewritten subplan. */
        public PlanNode getNode() {
            return this.node;
        }
    }
}

