/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.Predicates;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Decorrelates an {@link ApplyNode} whose subquery computes a global (no GROUP BY) aggregation.
 *
 * <p>The rewrite produces, roughly:
 * <pre>
 *   [Project (the subquery's projection above the aggregation, if any)]
 *     Aggregation (grouped by every output symbol of the uniquely-tagged input)
 *       LeftJoin (filter = correlated predicates)
 *         AssignUniqueId(input)
 *         Project(decorrelated aggregation source + non_null := TRUE)
 * </pre>
 * The {@code non_null} column is TRUE on every row that comes from the subquery side, so it is
 * NULL only on the NULL-padded rows the LEFT JOIN emits for input rows without a match.
 * Argument-less {@code count()} is rewritten to {@code count(non_null)} so those padded rows are
 * not counted.
 *
 * <p>Only a narrow shape is handled: the aggregation source may contain at most one filter
 * (reachable through projections and limits), that filter must reference every correlation
 * symbol, and its predicate must pass {@code isSupportedPredicate}. Otherwise the apply node is
 * left unchanged.
 */
public class TransformCorrelatedScalarAggregationToJoin
implements PlanOptimizer {
    private static final QualifiedName COUNT = QualifiedName.of("count");

    private final Metadata metadata;

    public TransformCorrelatedScalarAggregationToJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, symbolAllocator, this.metadata), plan, null);
    }

    /**
     * Builds identity assignments ({@code symbol := symbol}) for the given symbols, in iteration order.
     */
    private static Map<Symbol, Expression> toAssignments(Collection<Symbol> symbols) {
        return symbols.stream().collect(ImmutableCollectors.toImmutableMap(symbol -> symbol, Symbol::toSymbolReference));
    }

    /**
     * Rewrites every {@link ProjectNode} in a subtree so that each symbol in {@code symbols} that
     * its source produces, but which it does not already output, is passed through as an identity
     * assignment. This keeps join-predicate symbols visible at the top of the subtree.
     */
    private static class ExtendProjectionRewriter
    extends SimplePlanRewriter<PlanNode> {
        private final PlanNodeIdAllocator idAllocator;
        private final Set<Symbol> symbols;

        ExtendProjectionRewriter(PlanNodeIdAllocator idAllocator, Set<Symbol> symbols) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbols = Objects.requireNonNull(symbols, "symbols is null");
        }

        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            ProjectNode rewrittenNode = (ProjectNode)context.defaultRewrite(node, context.get());
            List<Symbol> sourceSymbols = rewrittenNode.getSource().getOutputSymbols();
            List<Symbol> outputSymbols = rewrittenNode.getOutputSymbols();
            List<Symbol> symbolsToAdd = this.symbols.stream()
                    .filter(sourceSymbols::contains)
                    .filter(symbol -> !outputSymbols.contains(symbol))
                    .collect(ImmutableCollectors.toImmutableList());
            // symbolsToAdd excludes existing outputs, so the builder cannot see duplicate keys.
            Map<Symbol, Expression> assignments = ImmutableMap.<Symbol, Expression>builder()
                    .putAll(rewrittenNode.getAssignments())
                    .putAll(TransformCorrelatedScalarAggregationToJoin.toAssignments(symbolsToAdd))
                    .build();
            return new ProjectNode(this.idAllocator.getNextId(), rewrittenNode.getSource(), assignments);
        }
    }

    /**
     * A subquery source with its correlated predicates pulled out: {@code node} no longer
     * references correlation symbols, and the predicates are to be used as the join filter.
     */
    private static class DecorrelatedNode {
        private final List<Expression> correlatedPredicates;
        private final PlanNode node;

        public DecorrelatedNode(List<Expression> correlatedPredicates, PlanNode node) {
            Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null");
            this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates);
            this.node = Objects.requireNonNull(node, "node is null");
        }

        /**
         * Returns the conjunction of the correlated predicates, or empty when there are none.
         */
        Optional<Expression> getCorrelatedPredicates() {
            if (this.correlatedPredicates.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(ExpressionUtils.and(this.correlatedPredicates));
        }

        public PlanNode getNode() {
            return this.node;
        }
    }

    private static class Rewriter
    extends SimplePlanRewriter<PlanNode> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        /**
         * Rewrites children first. Then, for a correlated apply whose subquery is (through
         * projections / EnforceSingleRow only) a global aggregation, attempts the decorrelation.
         */
        @Override
        public PlanNode visitApply(ApplyNode node, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            ApplyNode rewrittenNode = (ApplyNode)context.defaultRewrite(node, context.get());
            if (rewrittenNode.getCorrelation().isEmpty()) {
                return rewrittenNode;
            }
            Optional<PlanNode> aggregation = PlanNodeSearcher.searchFrom(rewrittenNode.getSubquery())
                    .where(AggregationNode.class::isInstance)
                    .skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class))
                    .findFirst();
            if (aggregation.isPresent() && ((AggregationNode)aggregation.get()).getGroupingKeys().isEmpty()) {
                return this.rewriteScalarAggregation(rewrittenNode, (AggregationNode)aggregation.get());
            }
            return rewrittenNode;
        }

        /**
         * Decorrelates the aggregation's source and appends the {@code non_null := TRUE} marker.
         * Returns {@code apply} unchanged if the source cannot be decorrelated.
         */
        private PlanNode rewriteScalarAggregation(ApplyNode apply, AggregationNode aggregation) {
            List<Symbol> correlation = apply.getCorrelation();
            Optional<DecorrelatedNode> source = this.decorrelateFilters(aggregation.getSource(), correlation);
            if (!source.isPresent()) {
                return apply;
            }
            DecorrelatedNode decorrelated = source.get();
            Symbol nonNull = this.symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN);
            Map<Symbol, Expression> scalarAggregationSourceAssignments = ImmutableMap.<Symbol, Expression>builder()
                    .putAll(TransformCorrelatedScalarAggregationToJoin.toAssignments(decorrelated.getNode().getOutputSymbols()))
                    .put(nonNull, BooleanLiteral.TRUE_LITERAL)
                    .build();
            ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode(this.idAllocator.getNextId(), decorrelated.getNode(), scalarAggregationSourceAssignments);
            return this.rewriteScalarAggregation(apply, aggregation, scalarAggregationSourceWithNonNullableSymbol, decorrelated.getCorrelatedPredicates(), nonNull);
        }

        /**
         * Builds LeftJoin(AssignUniqueId(input), source) filtered by {@code joinExpression}, the
         * re-grouped aggregation on top, and re-applies the subquery's projection if one sits
         * above the original aggregation.
         */
        private PlanNode rewriteScalarAggregation(ApplyNode applyNode, AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional<Expression> joinExpression, Symbol nonNull) {
            AssignUniqueId inputWithUniqueColumns = new AssignUniqueId(this.idAllocator.getNextId(), applyNode.getInput(), this.symbolAllocator.newSymbol("unique", BigintType.BIGINT));
            JoinNode leftOuterJoin = new JoinNode(this.idAllocator.getNextId(), JoinNode.Type.LEFT, inputWithUniqueColumns, scalarAggregationSource, ImmutableList.<JoinNode.EquiJoinClause>of(), joinExpression, Optional.empty(), Optional.empty());
            Optional<AggregationNode> aggregationNode = this.createAggregationNode(scalarAggregation, leftOuterJoin, nonNull);
            if (!aggregationNode.isPresent()) {
                return applyNode;
            }
            // Only look along the path visitApply walked to reach the aggregation. Searching the
            // whole subtree could find a projection *below* the aggregation, whose assignments
            // reference symbols that are not available above the new aggregation.
            Optional<PlanNode> subqueryProjection = PlanNodeSearcher.searchFrom(applyNode.getSubquery())
                    .where(ProjectNode.class::isInstance)
                    .skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class))
                    .findFirst();
            if (!subqueryProjection.isPresent()) {
                return aggregationNode.get();
            }
            Map<Symbol, Expression> projectionAssignments = ((ProjectNode)subqueryProjection.get()).getAssignments();
            // Pass through aggregation outputs the projection does not already define; a symbol
            // present in both would make ImmutableMap.Builder.build() throw on the duplicate key.
            List<Symbol> passThroughSymbols = aggregationNode.get().getOutputSymbols().stream()
                    .filter(symbol -> !projectionAssignments.containsKey(symbol))
                    .collect(ImmutableCollectors.toImmutableList());
            Map<Symbol, Expression> assignments = ImmutableMap.<Symbol, Expression>builder()
                    .putAll(TransformCorrelatedScalarAggregationToJoin.toAssignments(passThroughSymbols))
                    .putAll(projectionAssignments)
                    .build();
            return new ProjectNode(this.idAllocator.getNextId(), aggregationNode.get(), assignments);
        }

        /**
         * Re-creates {@code scalarAggregation} over the left join, grouped by all symbols of the
         * join's left (uniquely tagged input) side.
         *
         * <p>Argument-less {@code count()} (presumably how {@code count(*)} is represented — confirm
         * against the analyzer) is rewritten to {@code count(nonNullableAggregationSourceSymbol)} so
         * the join's NULL-padded rows count as 0. Calls with arguments, including
         * {@code count(x)}, are kept as-is. Their argument is already NULL on padded rows, and
         * replacing it would also count matched rows whose argument is NULL.
         */
        private Optional<AggregationNode> createAggregationNode(AggregationNode scalarAggregation, JoinNode leftOuterJoin, Symbol nonNullableAggregationSourceSymbol) {
            ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder();
            ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder();
            FunctionRegistry functionRegistry = this.metadata.getFunctionRegistry();
            for (Map.Entry<Symbol, FunctionCall> entry : scalarAggregation.getAggregations().entrySet()) {
                Symbol symbol = entry.getKey();
                FunctionCall call = entry.getValue();
                if (call.getName().equals(COUNT) && call.getArguments().isEmpty()) {
                    aggregations.put(symbol, new FunctionCall(COUNT, ImmutableList.<Expression>of(nonNullableAggregationSourceSymbol.toSymbolReference())));
                    List<TypeSignature> scalarAggregationSourceTypeSignatures = ImmutableList.of(this.symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol).getTypeSignature());
                    functions.put(symbol, functionRegistry.resolveFunction(COUNT, scalarAggregationSourceTypeSignatures));
                }
                else {
                    aggregations.put(symbol, call);
                    functions.put(symbol, scalarAggregation.getFunctions().get(symbol));
                }
            }
            List<Symbol> groupBySymbols = leftOuterJoin.getLeft().getOutputSymbols();
            return Optional.of(new AggregationNode(this.idAllocator.getNextId(), leftOuterJoin, aggregations.build(), functions.build(), scalarAggregation.getMasks(), ImmutableList.of(groupBySymbols), scalarAggregation.getStep(), scalarAggregation.getHashSymbol(), Optional.empty()));
        }

        /**
         * Splits the single filter under {@code node} (reachable through projections and limits)
         * into correlated and uncorrelated conjuncts.
         *
         * <p>The uncorrelated conjuncts stay in the plan. The correlated ones are returned for use
         * as the join filter, and their symbols are projected up to the top of the subtree.
         *
         * <p>Returns empty when the shape is unsupported: more than one such filter, a predicate
         * rejected by {@link #isSupportedPredicate}, a filter that does not reference every
         * correlation symbol, or a result that still references correlation symbols.
         */
        private Optional<DecorrelatedNode> decorrelateFilters(PlanNode node, List<Symbol> correlation) {
            PlanNodeSearcher filterNodeSearcher = PlanNodeSearcher.searchFrom(node)
                    .where(FilterNode.class::isInstance)
                    .skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class, LimitNode.class));
            List<PlanNode> filterNodes = filterNodeSearcher.findAll();
            if (filterNodes.isEmpty()) {
                return this.decorrelatedNode(ImmutableList.of(), node, correlation);
            }
            if (filterNodes.size() > 1) {
                return Optional.empty();
            }
            FilterNode filterNode = (FilterNode)filterNodes.get(0);
            Expression predicate = filterNode.getPredicate();
            if (!Rewriter.isSupportedPredicate(predicate)) {
                return Optional.empty();
            }
            if (!DependencyExtractor.extractUnique(predicate).containsAll(correlation)) {
                return Optional.empty();
            }
            Map<Boolean, List<Expression>> predicates = ExpressionUtils.extractConjuncts(predicate).stream()
                    .collect(Collectors.partitioningBy(this.isUsingPredicate(correlation)));
            List<Expression> correlatedPredicates = ImmutableList.copyOf(predicates.get(true));
            List<Expression> uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false));
            PlanNode decorrelated = this.updateFilterNode(filterNodeSearcher, uncorrelatedPredicates);
            if (!correlatedPredicates.isEmpty()) {
                decorrelated = this.removeLimitNode(decorrelated);
            }
            decorrelated = this.ensureJoinSymbolsAreReturned(decorrelated, correlatedPredicates);
            return this.decorrelatedNode(correlatedPredicates, decorrelated, correlation);
        }

        /**
         * Wraps {@code node} as a {@link DecorrelatedNode}, or returns empty if the subtree still
         * references any correlation symbol.
         */
        private Optional<DecorrelatedNode> decorrelatedNode(List<Expression> correlatedPredicates, PlanNode node, List<Symbol> correlation) {
            if (DependencyExtractor.extractUnique(node).stream().anyMatch(correlation::contains)) {
                return Optional.empty();
            }
            return Optional.of(new DecorrelatedNode(correlatedPredicates, node));
        }

        /**
         * Returns a predicate that is true for expressions referencing at least one of {@code symbols}.
         */
        private Predicate<Expression> isUsingPredicate(List<Symbol> symbols) {
            return expression -> {
                Set<Symbol> referenced = DependencyExtractor.extractUnique(expression);
                return symbols.stream().anyMatch(referenced::contains);
            };
        }

        /**
         * Replaces the (single) filter found by {@code filterNodeSearcher} with one using
         * {@code newPredicates}, or removes it entirely when there are none.
         */
        private PlanNode updateFilterNode(PlanNodeSearcher filterNodeSearcher, List<Expression> newPredicates) {
            if (newPredicates.isEmpty()) {
                return filterNodeSearcher.removeAll();
            }
            FilterNode oldFilterNode = (FilterNode)Iterables.getOnlyElement(filterNodeSearcher.findAll());
            FilterNode newFilterNode = new FilterNode(this.idAllocator.getNextId(), oldFilterNode.getSource(), ExpressionUtils.combineConjuncts(newPredicates));
            return filterNodeSearcher.replaceAll(newFilterNode);
        }

        /**
         * Removes the first {@link LimitNode} reachable from {@code node} through projections.
         *
         * <p>NOTE(review): dropping a LIMIT changes how many rows feed the aggregation (e.g.
         * {@code count(*)} over a LIMITed correlated subquery); confirm this is intended.
         */
        private PlanNode removeLimitNode(PlanNode node) {
            return PlanNodeSearcher.searchFrom(node)
                    .where(LimitNode.class::isInstance)
                    .skipOnlyWhen(ProjectNode.class::isInstance)
                    .removeFirst();
        }

        /**
         * Extends the projections in {@code scalarAggregationSource} so every symbol used by
         * {@code joinPredicate} is still output at the top, where the join filter needs it.
         */
        private PlanNode ensureJoinSymbolsAreReturned(PlanNode scalarAggregationSource, List<Expression> joinPredicate) {
            Set<Symbol> joinExpressionSymbols = DependencyExtractor.extractUnique(joinPredicate);
            ExtendProjectionRewriter extendProjectionRewriter = new ExtendProjectionRewriter(this.idAllocator, joinExpressionSymbols);
            return SimplePlanRewriter.rewriteWith(extendProjectionRewriter, scalarAggregationSource);
        }

        /**
         * Returns false if traversal of {@code predicate} reaches a {@link LogicalBinaryExpression}
         * that is not an AND.
         *
         * <p>The override does not call {@code super}, so operands of any logical binary
         * expression are not visited. An OR nested directly beneath an AND is therefore accepted.
         * That is harmless, because such an OR stays inside a single conjunct after
         * {@code extractConjuncts}.
         */
        private static boolean isSupportedPredicate(Expression predicate) {
            AtomicBoolean isSupported = new AtomicBoolean(true);
            new DefaultTraversalVisitor<Void, AtomicBoolean>(){

                @Override
                protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, AtomicBoolean context) {
                    if (node.getType() != LogicalBinaryExpression.Type.AND) {
                        context.set(false);
                    }
                    return null;
                }
            }.process(predicate, isSupported);
            return isSupported.get();
        }
    }
}

