/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.WhenClause;
import com.facebook.presto.sql.util.AstUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;

public class TransformCorrelatedInPredicateToJoin
implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode().with(Pattern.nonEmpty(Patterns.Apply.correlation()));

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(ApplyNode apply, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = apply.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }
        Expression assignmentExpression = (Expression)Iterables.getOnlyElement(subqueryAssignments.getExpressions());
        if (!(assignmentExpression instanceof InPredicate)) {
            return Rule.Result.empty();
        }
        InPredicate inPredicate = (InPredicate)assignmentExpression;
        Symbol inPredicateOutputSymbol = (Symbol)Iterables.getOnlyElement(subqueryAssignments.getSymbols());
        return this.apply(apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator());
    }

    private Rule.Result apply(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation()).decorrelate(apply.getSubquery());
        if (!decorrelated.isPresent()) {
            return Rule.Result.empty();
        }
        PlanNode projection = this.buildInPredicateEquivalent(apply, inPredicate, inPredicateOutputSymbol, decorrelated.get(), idAllocator, symbolAllocator);
        return Rule.Result.ofPlanNode(projection);
    }

    private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
        Expression correlationCondition = ExpressionUtils.and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();
        AssignUniqueId probeSide = new AssignUniqueId(idAllocator.getNextId(), apply.getInput(), symbolAllocator.newSymbol("unique", (Type)BigintType.BIGINT));
        Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", (Type)BigintType.BIGINT);
        ProjectNode buildSide = new ProjectNode(idAllocator.getNextId(), decorrelatedBuildSource, Assignments.builder().putIdentities(decorrelatedBuildSource.getOutputSymbols()).put(buildSideKnownNonNull, TransformCorrelatedInPredicateToJoin.bigint(0L)).build());
        Symbol probeSideSymbol = Symbol.from(inPredicate.getValue());
        Symbol buildSideSymbol = Symbol.from(inPredicate.getValueList());
        Expression joinExpression = ExpressionUtils.and(ExpressionUtils.or(new Expression[]{new IsNullPredicate((Expression)probeSideSymbol.toSymbolReference()), new ComparisonExpression(ComparisonExpressionType.EQUAL, (Expression)probeSideSymbol.toSymbolReference(), (Expression)buildSideSymbol.toSymbolReference()), new IsNullPredicate((Expression)buildSideSymbol.toSymbolReference())}), correlationCondition);
        JoinNode leftOuterJoin = TransformCorrelatedInPredicateToJoin.leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);
        Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", (Type)BigintType.BIGINT);
        Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", (Type)BigintType.BIGINT);
        Expression matchCondition = ExpressionUtils.and(TransformCorrelatedInPredicateToJoin.isNotNull(probeSideSymbol), TransformCorrelatedInPredicateToJoin.isNotNull(buildSideSymbol));
        Expression nullMatchCondition = ExpressionUtils.and(TransformCorrelatedInPredicateToJoin.isNotNull(buildSideKnownNonNull), TransformCorrelatedInPredicateToJoin.not(matchCondition));
        AggregationNode aggregation = new AggregationNode(idAllocator.getNextId(), leftOuterJoin, (Map<Symbol, AggregationNode.Aggregation>)ImmutableMap.builder().put((Object)countMatchesSymbol, (Object)TransformCorrelatedInPredicateToJoin.countWithFilter(matchCondition)).put((Object)countNullMatchesSymbol, (Object)TransformCorrelatedInPredicateToJoin.countWithFilter(nullMatchCondition)).build(), (List<List<Symbol>>)ImmutableList.of(probeSide.getOutputSymbols()), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty());
        SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression((List)ImmutableList.of((Object)new WhenClause(TransformCorrelatedInPredicateToJoin.isGreaterThan(countMatchesSymbol, 0L), TransformCorrelatedInPredicateToJoin.booleanConstant(true)), (Object)new WhenClause(TransformCorrelatedInPredicateToJoin.isGreaterThan(countNullMatchesSymbol, 0L), TransformCorrelatedInPredicateToJoin.booleanConstant(null))), Optional.of(TransformCorrelatedInPredicateToJoin.booleanConstant(false)));
        return new ProjectNode(idAllocator.getNextId(), aggregation, Assignments.builder().putIdentities(apply.getInput().getOutputSymbols()).put(inPredicateOutputSymbol, (Expression)inPredicateEquivalent).build());
    }

    private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, Expression joinExpression) {
        return new JoinNode(idAllocator.getNextId(), JoinNode.Type.LEFT, probeSide, buildSide, (List<JoinNode.EquiJoinClause>)ImmutableList.of(), (List<Symbol>)ImmutableList.builder().addAll(probeSide.getOutputSymbols()).addAll(buildSide.getOutputSymbols()).build(), Optional.of(joinExpression), Optional.empty(), Optional.empty(), Optional.empty());
    }

    private static AggregationNode.Aggregation countWithFilter(Expression condition) {
        FunctionCall countCall = new FunctionCall(QualifiedName.of((String)"count"), Optional.empty(), Optional.of(condition), Optional.empty(), false, (List)ImmutableList.of());
        return new AggregationNode.Aggregation(countCall, new Signature("count", FunctionKind.AGGREGATE, BigintType.BIGINT.getTypeSignature(), new TypeSignature[0]), Optional.empty());
    }

    private static Expression isGreaterThan(Symbol symbol, long value) {
        return new ComparisonExpression(ComparisonExpressionType.GREATER_THAN, (Expression)symbol.toSymbolReference(), TransformCorrelatedInPredicateToJoin.bigint(value));
    }

    private static Expression not(Expression booleanExpression) {
        return new NotExpression(booleanExpression);
    }

    private static Expression isNotNull(Symbol symbol) {
        return new IsNotNullPredicate((Expression)symbol.toSymbolReference());
    }

    private static Expression bigint(long value) {
        return new Cast((Expression)new LongLiteral(String.valueOf(value)), BigintType.BIGINT.toString());
    }

    private static Expression booleanConstant(@Nullable Boolean value) {
        if (value == null) {
            return new Cast((Expression)new NullLiteral(), BooleanType.BOOLEAN.toString());
        }
        return new BooleanLiteral(value.toString());
    }

    private static class Decorrelated {
        private final List<Expression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        public Decorrelated(List<Expression> correlatedPredicates, PlanNode decorrelatedNode) {
            this.correlatedPredicates = ImmutableList.copyOf((Collection)Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = Objects.requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        public List<Expression> getCorrelatedPredicates() {
            return this.correlatedPredicates;
        }

        public PlanNode getDecorrelatedNode() {
            return this.decorrelatedNode;
        }
    }

    private static class DecorrelatingVisitor
    extends PlanVisitor<Optional<Decorrelated>, PlanNode> {
        private final Lookup lookup;
        private final Set<Symbol> correlation;

        public DecorrelatingVisitor(Lookup lookup, Iterable<Symbol> correlation) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(Objects.requireNonNull(correlation, "correlation is null"));
        }

        public Optional<Decorrelated> decorrelate(PlanNode reference) {
            return this.lookup.resolve(reference).accept(this, reference);
        }

        @Override
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference) {
            if (this.isCorrelatedShallowly(node)) {
                return Optional.empty();
            }
            Optional<Decorrelated> result = this.decorrelate(node.getSource());
            return result.map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments());
                decorrelated.getCorrelatedPredicates().stream().flatMap(AstUtils::preOrder).filter(SymbolReference.class::isInstance).map(SymbolReference.class::cast).filter(symbolReference -> !this.correlation.contains(Symbol.from((Expression)symbolReference))).forEach(symbolReference -> assignments.putIdentity(Symbol.from((Expression)symbolReference)));
                return new Decorrelated(decorrelated.getCorrelatedPredicates(), new ProjectNode(node.getId(), decorrelated.getDecorrelatedNode(), assignments.build()));
            });
        }

        @Override
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference) {
            Optional<Decorrelated> result = this.decorrelate(node.getSource());
            return result.map(decorrelated -> new Decorrelated((List<Expression>)ImmutableList.builder().addAll(decorrelated.getCorrelatedPredicates()).add((Object)node.getPredicate()).build(), decorrelated.getDecorrelatedNode()));
        }

        @Override
        protected Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference) {
            if (this.isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            return Optional.of(new Decorrelated((List<Expression>)ImmutableList.of(), reference));
        }

        private boolean isCorrelatedRecursively(PlanNode node) {
            if (this.isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream().map(this.lookup::resolve).anyMatch(this::isCorrelatedRecursively);
        }

        private boolean isCorrelatedShallowly(PlanNode node) {
            return SymbolsExtractor.extractUniqueNonRecursive(node).stream().anyMatch(this.correlation::contains);
        }
    }
}

