/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.NoOpSymbolResolver;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Plan optimizer that simplifies the expressions carried by {@link ProjectNode} assignments,
 * {@link FilterNode} predicates and the original constraint of {@link TableScanNode}s.
 * <p>
 * Each expression goes through three steps (see {@code Rewriter.simplifyExpression}):
 * <ol>
 * <li>negations are pushed towards the leaves ({@link PushDownNegationsExpressionRewriter});</li>
 * <li>deterministic predicates shared by every operand of an AND/OR are factored out, and an OR
 * at the root of the expression may be distributed into an AND
 * ({@link ExtractCommonPredicatesExpressionRewriter});</li>
 * <li>the result is passed through {@link ExpressionInterpreter#expressionOptimizer} with a
 * {@link NoOpSymbolResolver} and turned back into an expression.</li>
 * </ol>
 * A filter whose simplified predicate is {@code TRUE} is removed; one whose predicate is
 * {@code FALSE} or a {@code NULL} literal is replaced by an empty {@link ValuesNode}.
 */
public class SimplifyExpressions
        implements PlanOptimizer {
    private final Metadata metadata;
    private final SqlParser sqlParser;

    /**
     * @param metadata  metadata used for expression type analysis and optimization; must not be null
     * @param sqlParser parser handed to the expression analyzer; must not be null
     * @throws NullPointerException if either argument is null
     */
    public SimplifyExpressions(Metadata metadata, SqlParser sqlParser) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.sqlParser = Objects.requireNonNull(sqlParser, "sqlParser is null");
    }

    /**
     * Rewrites {@code plan}, simplifying the expressions of project, filter and table-scan nodes.
     *
     * @throws NullPointerException if any argument is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(metadata, sqlParser, session, types, idAllocator), plan);
    }

    /**
     * Factors predicates that are common to every operand of an AND/OR out of that expression,
     * e.g. {@code (A AND B) OR (A AND C)} becomes {@code A AND (B OR C)}. When the outermost
     * expression is (or becomes) an OR, it additionally tries to distribute it into an AND
     * (see {@link #distributeIfPossible}).
     */
    private static class ExtractCommonPredicatesExpressionRewriter
            extends ExpressionRewriter<NodeContext> {

        /**
         * Fallback for every node type without a dedicated handler. At the root, the same node is
         * re-visited with {@link NodeContext#NOT_ROOT_NODE}, so that nothing below a non-logical
         * root is treated as the root. Returning {@code null} presumably tells
         * {@link ExpressionTreeRewriter} to apply its default rewrite of the node's children.
         */
        @Override
        public Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            if (context.isRootNode()) {
                return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
            }
            return null;
        }

        /**
         * Rewrites every operand of {@code node} (as a non-root expression), recombines them with the
         * same operator and then factors out common predicates. Only the root expression is
         * considered for OR-to-AND distribution.
         */
        @Override
        public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            List<Expression> rewrittenOperands = ExpressionUtils.extractPredicates(node.getType(), node).stream()
                    .map(subExpression -> treeRewriter.rewrite(subExpression, NodeContext.NOT_ROOT_NODE))
                    .collect(ImmutableCollectors.toImmutableList());
            Expression expression = ExpressionUtils.combinePredicates(node.getType(), rewrittenOperands);

            if (!(expression instanceof LogicalBinaryExpression)) {
                return expression;
            }

            Expression simplified = extractCommonPredicates((LogicalBinaryExpression) expression);

            // Prefer an AND at the root: try to turn a root-level OR into an AND of ORs.
            if (context.isRootNode()
                    && simplified instanceof LogicalBinaryExpression
                    && ((LogicalBinaryExpression) simplified).getType() == LogicalBinaryExpression.Type.OR) {
                return distributeIfPossible((LogicalBinaryExpression) simplified);
            }
            return simplified;
        }

        /**
         * Pulls the deterministic predicates shared by all operands of {@code node} out of it.
         * <p>
         * Example for {@code (A AND B) OR (A AND C)}: the sub-predicates are {@code [[A, B], [A, C]]},
         * the common set is {@code {A}}, the remainders {@code [[B], [C]]} are recombined as
         * {@code B OR C}, and the result is {@code A AND (B OR C)}.
         */
        private static Expression extractCommonPredicates(LogicalBinaryExpression node) {
            List<List<Expression>> subPredicates = getSubPredicates(node);

            // Only deterministic predicates may be shared: two textually equal non-deterministic
            // predicates are separate evaluations and cannot be merged into one.
            Set<Expression> commonPredicates = ImmutableSet.copyOf(subPredicates.stream()
                    .map(ExtractCommonPredicatesExpressionRewriter::filterDeterministicPredicates)
                    .reduce(Sets::intersection)
                    .orElse(Collections.emptySet()));

            List<List<Expression>> uncorrelatedSubPredicates = subPredicates.stream()
                    .map(predicateList -> removeAll(predicateList, commonPredicates))
                    .collect(ImmutableCollectors.toImmutableList());

            LogicalBinaryExpression.Type flippedNodeType = node.getType().flip();
            List<Expression> uncorrelatedPredicates = uncorrelatedSubPredicates.stream()
                    .map(predicates -> ExpressionUtils.combinePredicates(flippedNodeType, predicates))
                    .collect(ImmutableCollectors.toImmutableList());
            Expression combinedUncorrelatedPredicates = ExpressionUtils.combinePredicates(node.getType(), uncorrelatedPredicates);

            return ExpressionUtils.combinePredicates(flippedNodeType, ImmutableList.<Expression>builder()
                    .addAll(commonPredicates)
                    .add(combinedUncorrelatedPredicates)
                    .build());
        }

        /**
         * Splits {@code expression} into its top-level operands for its own operator; each operand that
         * is itself a {@link LogicalBinaryExpression} is further split by that operand's operator, and
         * any other operand becomes a singleton list.
         */
        private static List<List<Expression>> getSubPredicates(LogicalBinaryExpression expression) {
            return ExpressionUtils.extractPredicates(expression.getType(), expression).stream()
                    .map(predicate -> predicate instanceof LogicalBinaryExpression
                            ? ExpressionUtils.extractPredicates((LogicalBinaryExpression) predicate)
                            : ImmutableList.of(predicate))
                    .collect(ImmutableCollectors.toImmutableList());
        }

        /**
         * Applies the boolean distributive law, e.g.
         * {@code (A AND B) OR (C AND D)} becomes
         * {@code (A OR C) AND (A OR D) AND (B OR C) AND (B OR D)}.
         * <p>
         * {@code expression} is returned unchanged when it is non-deterministic, when counting the
         * resulting base expressions overflows an {@code int}, or when the distributed form would
         * contain more than twice as many base expressions as the original.
         */
        private static Expression distributeIfPossible(LogicalBinaryExpression expression) {
            if (!DeterminismEvaluator.isDeterministic(expression)) {
                // Distribution duplicates terms, which would change how often a non-deterministic
                // term is evaluated.
                return expression;
            }
            List<Set<Expression>> subPredicates = getSubPredicates(expression).stream()
                    .<Set<Expression>>map(ImmutableSet::copyOf)
                    .collect(Collectors.toList());

            int originalBaseExpressions = subPredicates.stream()
                    .mapToInt(Set::size)
                    .sum();
            // Distribution yields (product of set sizes) clauses, each holding one element from
            // every set, i.e. subPredicates.size() base expressions per clause.
            int newBaseExpressions;
            try {
                newBaseExpressions = Math.multiplyExact(subPredicates.stream()
                        .mapToInt(Set::size)
                        .reduce(Math::multiplyExact)
                        .getAsInt(), subPredicates.size());
            }
            catch (ArithmeticException ignored) {
                // Overflow means the distributed form would be far too large; keep the original.
                return expression;
            }
            // Arbitrary heuristic to avoid a cross-product explosion.
            if (newBaseExpressions > originalBaseExpressions * 2) {
                return expression;
            }

            Set<List<Expression>> crossProduct = Sets.cartesianProduct(subPredicates);
            List<Expression> clauses = crossProduct.stream()
                    .map(expressions -> ExpressionUtils.combinePredicates(expression.getType(), expressions))
                    .collect(ImmutableCollectors.toImmutableList());
            return ExpressionUtils.combinePredicates(expression.getType().flip(), clauses);
        }

        /** Returns the deterministic members of {@code predicates} as a set. */
        private static Set<Expression> filterDeterministicPredicates(List<Expression> predicates) {
            return predicates.stream()
                    .filter(DeterminismEvaluator::isDeterministic)
                    .collect(Collectors.toSet());
        }

        /** Returns the elements of {@code collection}, in iteration order, that are not in {@code elementsToRemove}. */
        private static <T> List<T> removeAll(Collection<T> collection, Collection<T> elementsToRemove) {
            return collection.stream()
                    .filter(element -> !elementsToRemove.contains(element))
                    .collect(ImmutableCollectors.toImmutableList());
        }
    }

    /** Tells {@link ExtractCommonPredicatesExpressionRewriter} whether it is visiting the outermost expression. */
    private enum NodeContext {
        ROOT_NODE,
        NOT_ROOT_NODE;

        boolean isRootNode() {
            return this == ROOT_NODE;
        }
    }

    /**
     * Pushes {@code NOT} towards the leaves of an expression:
     * <ul>
     * <li>{@code NOT} over AND/OR is expanded with De Morgan's law (each operand negated and rewritten,
     * operator flipped);</li>
     * <li>{@code NOT} over a comparison other than {@code IS DISTINCT FROM} becomes the comparison with
     * the negated operator;</li>
     * <li>{@code NOT NOT x} becomes {@code x} (rewritten).</li>
     * </ul>
     */
    private static class PushDownNegationsExpressionRewriter
            extends ExpressionRewriter<Void> {

        @Override
        public Expression rewriteNotExpression(NotExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            if (node.getValue() instanceof LogicalBinaryExpression) {
                LogicalBinaryExpression child = (LogicalBinaryExpression) node.getValue();
                List<Expression> predicates = ExpressionUtils.extractPredicates(child);
                // The (Expression) widening matters: rewrite() is generic in its argument type, and the
                // rewritten result is generally not a NotExpression.
                List<Expression> negatedPredicates = predicates.stream()
                        .map(predicate -> treeRewriter.rewrite((Expression) new NotExpression(predicate), context))
                        .collect(ImmutableCollectors.toImmutableList());
                return ExpressionUtils.combinePredicates(child.getType().flip(), negatedPredicates);
            }
            // NOTE(review): IS_DISTINCT_FROM is skipped, presumably because it has no negated
            // ComparisonExpressionType — confirm.
            if (node.getValue() instanceof ComparisonExpression
                    && ((ComparisonExpression) node.getValue()).getType() != ComparisonExpressionType.IS_DISTINCT_FROM) {
                ComparisonExpression child = (ComparisonExpression) node.getValue();
                return new ComparisonExpression(
                        child.getType().negate(),
                        treeRewriter.rewrite(child.getLeft(), context),
                        treeRewriter.rewrite(child.getRight(), context));
            }
            if (node.getValue() instanceof NotExpression) {
                NotExpression child = (NotExpression) node.getValue();
                return treeRewriter.rewrite(child.getValue(), context);
            }
            return new NotExpression(treeRewriter.rewrite(node.getValue(), context));
        }
    }

    /** Plan rewriter that simplifies the expressions of project, filter and table-scan nodes. */
    private static class Rewriter
            extends SimplePlanRewriter<Void> {
        private final Metadata metadata;
        private final SqlParser sqlParser;
        private final Session session;
        private final Map<Symbol, Type> types;
        private final PlanNodeIdAllocator idAllocator;

        public Rewriter(Metadata metadata, SqlParser sqlParser, Session session, Map<Symbol, Type> types, PlanNodeIdAllocator idAllocator) {
            this.metadata = metadata;
            this.sqlParser = sqlParser;
            this.session = session;
            this.types = types;
            this.idAllocator = idAllocator;
        }

        /** Rewrites the source, then simplifies every assignment expression. */
        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode source = context.rewrite(node.getSource());
            return new ProjectNode(node.getId(), source, node.getAssignments().rewrite(this::simplifyExpression));
        }

        /**
         * Rewrites the source and simplifies the predicate. A {@code TRUE} predicate drops the filter.
         * A {@code FALSE} predicate or a {@code NULL} literal is treated the same way: the filter is
         * replaced by a {@link ValuesNode} with the filter's output symbols and no rows.
         */
        @Override
        public PlanNode visitFilter(FilterNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode source = context.rewrite(node.getSource());
            Expression simplified = simplifyExpression(node.getPredicate());
            if (simplified.equals(BooleanLiteral.TRUE_LITERAL)) {
                return source;
            }
            if (simplified.equals(BooleanLiteral.FALSE_LITERAL) || simplified instanceof NullLiteral) {
                return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.<List<Expression>>of());
            }
            return new FilterNode(node.getId(), source, simplified);
        }

        /**
         * Simplifies the scan's original constraint, when present; every other property, including
         * the current constraint, is copied unchanged.
         */
        @Override
        public PlanNode visitTableScan(TableScanNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            Expression originalConstraint = null;
            if (node.getOriginalConstraint() != null) {
                originalConstraint = simplifyExpression(node.getOriginalConstraint());
            }
            return new TableScanNode(
                    node.getId(),
                    node.getTable(),
                    node.getOutputSymbols(),
                    node.getAssignments(),
                    node.getLayout(),
                    node.getCurrentConstraint(),
                    originalConstraint);
        }

        /**
         * Pushes negations down, extracts common predicates, then runs the expression optimizer.
         * A bare {@link SymbolReference} is returned as is.
         */
        private Expression simplifyExpression(Expression expression) {
            if (expression instanceof SymbolReference) {
                return expression;
            }
            expression = ExpressionTreeRewriter.rewriteWith(new PushDownNegationsExpressionRewriter(), expression);
            expression = ExpressionTreeRewriter.rewriteWith(new ExtractCommonPredicatesExpressionRewriter(), expression, NodeContext.ROOT_NODE);
            // No query parameters are supplied to the analyzer; presumably they have already been
            // substituted by this stage — TODO confirm.
            IdentityHashMap<Expression, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, sqlParser, types, expression, Collections.emptyList());
            ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes);
            return LiteralInterpreter.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(expression));
        }
    }
}

