/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.PlanVisitor;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InSubqueryExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import com.google.common.graph.Traverser;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Rewrites a correlated {@code value IN (subquery)} into an uncorrelated left outer join
 * followed by a single-step aggregation.
 *
 * <p>The rule fires on an {@link ApplyNode} with a non-empty correlation list whose only
 * subquery assignment is an {@link InSubqueryExpression}. The subquery is decorrelated by
 * {@link DecorrelatingVisitor}: predicates of {@link FilterNode}s are pulled out and become
 * part of the join condition. The produced plan has this shape:
 *
 * <pre>
 * Project(apply input columns,
 *         out := CASE WHEN countMatches &gt; 0 THEN true
 *                     WHEN countNullMatches &gt; 0 THEN null
 *                     ELSE false END)
 *   Aggregation SINGLE (GROUP BY probe-side columns, including the unique id;
 *     countMatches     := count() FILTER (value IS NOT NULL AND subqueryValue IS NOT NULL),
 *     countNullMatches := count() FILTER (buildSideKnownNonNull IS NOT NULL
 *                                         AND NOT (value IS NOT NULL AND subqueryValue IS NOT NULL)))
 *     LeftJoin ON (value IS NULL OR value = subqueryValue OR subqueryValue IS NULL)
 *                 AND (pulled-up correlated predicates)
 *       AssignUniqueId(unique) over the apply input
 *       Project(decorrelated subquery outputs, buildSideKnownNonNull := 0)
 * </pre>
 *
 * <p>The CASE yields SQL three-valued {@code IN} semantics:
 * <ul>
 *   <li>TRUE when some joined row matched on two non-null values;</li>
 *   <li>NULL when rows joined only through a NULL on either side;</li>
 *   <li>FALSE when no build row joined (then {@code buildSideKnownNonNull} is NULL).</li>
 * </ul>
 */
public class TransformCorrelatedInPredicateToJoin
        implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode()
            .with(Pattern.nonEmpty(Patterns.Apply.correlation()));

    private final FunctionResolution functionResolution;

    /**
     * @param functionAndTypeManager source of the function resolver used to build the
     *        comparison, {@code not} and {@code count} calls; must not be null
     */
    public TransformCorrelatedInPredicateToJoin(FunctionAndTypeManager functionAndTypeManager) {
        Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Applies the rewrite when the apply node carries exactly one {@link InSubqueryExpression}
     * assignment; otherwise returns an empty result.
     */
    @Override
    public Rule.Result apply(ApplyNode apply, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = apply.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }
        RowExpression assignmentExpression = Iterables.getOnlyElement(subqueryAssignments.getExpressions());
        if (!(assignmentExpression instanceof InSubqueryExpression)) {
            return Rule.Result.empty();
        }
        InSubqueryExpression inPredicate = (InSubqueryExpression) assignmentExpression;
        VariableReferenceExpression inPredicateOutputVariable = Iterables.getOnlyElement(subqueryAssignments.getVariables());
        return apply(
                apply,
                inPredicate,
                inPredicateOutputVariable,
                context.getLookup(),
                context.getIdAllocator(),
                context.getVariableAllocator());
    }

    /**
     * Decorrelates the subquery and, if that succeeds, replaces the apply node with the
     * join/aggregation equivalent. Returns an empty result when the subquery cannot be
     * decorrelated by {@link DecorrelatingVisitor}.
     */
    private Rule.Result apply(
            ApplyNode apply,
            InSubqueryExpression inPredicate,
            VariableReferenceExpression inPredicateOutputVariable,
            Lookup lookup,
            PlanNodeIdAllocator idAllocator,
            VariableAllocator variableAllocator) {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(
                lookup,
                apply.getCorrelation(),
                TypeProvider.viewOf(variableAllocator.getVariables()))
                .decorrelate(apply.getSubquery());
        if (!decorrelated.isPresent()) {
            return Rule.Result.empty();
        }
        PlanNode projection = buildInPredicateEquivalent(
                apply, inPredicate, inPredicateOutputVariable, decorrelated.get(), idAllocator, variableAllocator);
        return Rule.Result.ofPlanNode(projection);
    }

    /**
     * Builds the plan shown in the class documentation. The returned projection exposes
     * exactly the apply input's output variables plus {@code inPredicateOutputVariable}.
     */
    private PlanNode buildInPredicateEquivalent(
            ApplyNode apply,
            InSubqueryExpression inPredicate,
            VariableReferenceExpression inPredicateOutputVariable,
            Decorrelated decorrelated,
            PlanNodeIdAllocator idAllocator,
            VariableAllocator variableAllocator) {
        RowExpression correlationCondition = LogicalRowExpressions.and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

        // Tag every probe row with a unique id so the aggregation below produces one
        // group per original input row, even when input rows are duplicates.
        AssignUniqueId probeSide = new AssignUniqueId(
                apply.getSourceLocation(),
                idAllocator.getNextId(),
                apply.getInput(),
                variableAllocator.newVariable("unique", BigintType.BIGINT));

        // A constant non-null marker on the build side: after the LEFT join it is NULL
        // exactly when no build row joined the probe row.
        VariableReferenceExpression buildSideKnownNonNull = variableAllocator.newVariable(
                inPredicateOutputVariable.getSourceLocation(), "buildSideKnownNonNull", BigintType.BIGINT);
        ProjectNode buildSide = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedBuildSource,
                Assignments.builder()
                        .putAll(AssignmentUtils.identityAssignments(decorrelatedBuildSource.getOutputVariables()))
                        .put(buildSideKnownNonNull, Expressions.constant(0L, BigintType.BIGINT))
                        .build());

        VariableReferenceExpression probeSideSymbolReference = inPredicate.getValue();
        VariableReferenceExpression buildSideSymbolReference = inPredicate.getSubquery();

        RowExpression isProbeSideNull = Expressions.specialForm(
                probeSideSymbolReference.getSourceLocation(),
                SpecialFormExpression.Form.IS_NULL,
                BooleanType.BOOLEAN,
                probeSideSymbolReference);
        RowExpression isBuildSideNull = Expressions.specialForm(
                buildSideSymbolReference.getSourceLocation(),
                SpecialFormExpression.Form.IS_NULL,
                BooleanType.BOOLEAN,
                buildSideSymbolReference);
        RowExpression comparison = Expressions.call(
                ComparisonExpression.Operator.EQUAL.name(),
                functionResolution.comparisonFunction(
                        ComparisonExpression.Operator.EQUAL,
                        probeSideSymbolReference.getType(),
                        buildSideSymbolReference.getType()),
                BooleanType.BOOLEAN,
                probeSideSymbolReference,
                buildSideSymbolReference);
        // Join rows that either match or involve a NULL on either side; the aggregation
        // below tells the two cases apart.
        RowExpression joinExpression = LogicalRowExpressions.and(
                LogicalRowExpressions.or(isProbeSideNull, comparison, isBuildSideNull),
                correlationCondition);

        JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);

        VariableReferenceExpression countMatchesVariable = variableAllocator.newVariable(
                buildSideSymbolReference.getSourceLocation(), "countMatches", BigintType.BIGINT);
        VariableReferenceExpression countNullMatchesVariable = variableAllocator.newVariable(
                buildSideSymbolReference.getSourceLocation(), "countNullMatches", BigintType.BIGINT);

        // A joined row with both values non-null must have satisfied the equality branch.
        RowExpression matchCondition = LogicalRowExpressions.and(
                isNotNull(probeSideSymbolReference),
                isNotNull(buildSideSymbolReference));
        // A row that really joined (marker non-null) but is not a real match joined via a NULL.
        RowExpression nullMatchCondition = LogicalRowExpressions.and(
                isNotNull(buildSideKnownNonNull),
                not(matchCondition));

        Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations =
                ImmutableMap.<VariableReferenceExpression, AggregationNode.Aggregation>builder()
                        .put(countMatchesVariable, countWithFilter(matchCondition))
                        .put(countNullMatchesVariable, countWithFilter(nullMatchCondition))
                        .build();
        AggregationNode aggregation = new AggregationNode(
                apply.getSourceLocation(),
                idAllocator.getNextId(),
                leftOuterJoin,
                aggregations,
                AggregationNode.singleGroupingSet(probeSide.getOutputVariables()),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());

        RowExpression whenAnyMatch = Expressions.specialForm(
                SpecialFormExpression.Form.WHEN,
                BooleanType.BOOLEAN,
                isGreaterThan(countMatchesVariable, 0L),
                LogicalRowExpressions.TRUE_CONSTANT);
        RowExpression whenAnyNullMatch = Expressions.specialForm(
                SpecialFormExpression.Form.WHEN,
                BooleanType.BOOLEAN,
                isGreaterThan(countNullMatchesVariable, 0L),
                new ConstantExpression(null, BooleanType.BOOLEAN));
        RowExpression inPredicateEquivalent = Expressions.searchedCaseExpression(
                ImmutableList.of(whenAnyMatch, whenAnyNullMatch),
                Optional.of(LogicalRowExpressions.FALSE_CONSTANT));

        return new ProjectNode(
                idAllocator.getNextId(),
                aggregation,
                Assignments.builder()
                        .putAll(AssignmentUtils.identityAssignments(apply.getInput().getOutputVariables()))
                        .put(inPredicateOutputVariable, inPredicateEquivalent)
                        .build());
    }

    /** Returns {@code NOT (expression IS NULL)}. */
    private RowExpression isNotNull(RowExpression expression) {
        return not(Expressions.specialForm(SpecialFormExpression.Form.IS_NULL, BooleanType.BOOLEAN, expression));
    }

    /** Returns a call to the boolean {@code not} function, keeping the argument's source location. */
    private RowExpression not(RowExpression expression) {
        return Expressions.call(
                expression.getSourceLocation(),
                "not",
                functionResolution.notFunction(),
                BooleanType.BOOLEAN,
                expression);
    }

    /**
     * Builds a LEFT join with no equi-join criteria, only the given filter, whose outputs are
     * all probe-side variables followed by all build-side variables.
     */
    private static JoinNode leftOuterJoin(
            PlanNodeIdAllocator idAllocator,
            AssignUniqueId probeSide,
            ProjectNode buildSide,
            RowExpression joinExpression) {
        List<VariableReferenceExpression> outputVariables = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(probeSide.getOutputVariables())
                .addAll(buildSide.getOutputVariables())
                .build();
        return new JoinNode(
                probeSide.getSourceLocation(),
                idAllocator.getNextId(),
                JoinType.LEFT,
                probeSide,
                buildSide,
                ImmutableList.of(),
                outputVariables,
                Optional.of(joinExpression),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());
    }

    /** Returns a non-distinct {@code count()} aggregation restricted by {@code condition}. */
    private AggregationNode.Aggregation countWithFilter(RowExpression condition) {
        CallExpression count = new CallExpression(
                condition.getSourceLocation(),
                "count",
                functionResolution.countFunction(),
                BigintType.BIGINT,
                ImmutableList.of());
        return new AggregationNode.Aggregation(count, Optional.of(condition), Optional.empty(), false, Optional.empty());
    }

    /** Returns {@code variable > value}, comparing as BIGINT. */
    private RowExpression isGreaterThan(VariableReferenceExpression variable, long value) {
        return Expressions.call(
                ComparisonExpression.Operator.GREATER_THAN.name(),
                functionResolution.comparisonFunction(
                        ComparisonExpression.Operator.GREATER_THAN, BigintType.BIGINT, BigintType.BIGINT),
                BooleanType.BOOLEAN,
                variable,
                Expressions.constant(value, BigintType.BIGINT));
    }

    /**
     * Removes correlation from a subquery plan by lifting correlated filter predicates out of it.
     *
     * <ul>
     *   <li>Filters are dropped from the plan and their whole predicate is collected.</li>
     *   <li>Projections are supported only when they themselves reference no correlation variable.</li>
     *   <li>Any other node is kept as-is, provided nothing in its subtree references a
     *       correlation variable.</li>
     * </ul>
     *
     * <p>Otherwise decorrelation fails with {@link Optional#empty()}.
     */
    private static class DecorrelatingVisitor
            extends InternalPlanVisitor<Optional<Decorrelated>, PlanNode> {
        private final Lookup lookup;
        private final Set<VariableReferenceExpression> correlation;
        // NOTE(review): not read anywhere in this class.
        private final TypeProvider types;

        public DecorrelatingVisitor(Lookup lookup, Iterable<VariableReferenceExpression> correlation, TypeProvider types) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(Objects.requireNonNull(correlation, "correlation is null"));
            this.types = Objects.requireNonNull(types, "types is null");
        }

        /** Resolves {@code reference} through the lookup and visits the resolved node. */
        public Optional<Decorrelated> decorrelate(PlanNode reference) {
            return lookup.resolve(reference).accept(this, reference);
        }

        @Override
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference) {
            if (isCorrelatedShallowly(node)) {
                return Optional.empty();
            }
            return decorrelate(node.getSource()).map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments());
                // The lifted predicates are evaluated above this projection, so re-expose (as
                // identity assignments) every non-correlation variable they reference.
                decorrelated.getCorrelatedPredicates().stream()
                        .flatMap(predicate -> Streams.stream(
                                Traverser.<RowExpression>forTree(RowExpression::getChildren).depthFirstPreOrder(predicate)))
                        .filter(VariableReferenceExpression.class::isInstance)
                        .map(VariableReferenceExpression.class::cast)
                        .filter(variable -> !correlation.contains(variable))
                        .forEach(variable -> assignments.putAll(AssignmentUtils.identityAssignments(variable)));
                return new Decorrelated(
                        decorrelated.getCorrelatedPredicates(),
                        new ProjectNode(node.getId(), decorrelated.getDecorrelatedNode(), assignments.build()));
            });
        }

        @Override
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference) {
            // The entire predicate (correlated and uncorrelated conjuncts alike) is lifted out
            // and the filter itself is removed from the plan.
            return decorrelate(node.getSource()).map(decorrelated -> new Decorrelated(
                    ImmutableList.<RowExpression>builder()
                            .addAll(decorrelated.getCorrelatedPredicates())
                            .add(node.getPredicate())
                            .build(),
                    decorrelated.getDecorrelatedNode()));
        }

        @Override
        public Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference) {
            if (isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            return Optional.of(new Decorrelated(ImmutableList.of(), reference));
        }

        /** True if {@code node} or any (resolved) descendant references a correlation variable. */
        private boolean isCorrelatedRecursively(PlanNode node) {
            if (isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream()
                    .map(lookup::resolve)
                    .anyMatch(this::isCorrelatedRecursively);
        }

        /** True if {@code node}'s own expressions (not its sources) reference a correlation variable. */
        private boolean isCorrelatedShallowly(PlanNode node) {
            return VariablesExtractor.extractUniqueNonRecursive(node).stream()
                    .anyMatch(correlation::contains);
        }
    }

    /** Result of decorrelation: the lifted predicates and the remaining, uncorrelated plan. */
    private static class Decorrelated {
        private final List<RowExpression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        public Decorrelated(List<RowExpression> correlatedPredicates, PlanNode decorrelatedNode) {
            this.correlatedPredicates = ImmutableList.copyOf(
                    Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = Objects.requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        public List<RowExpression> getCorrelatedPredicates() {
            return correlatedPredicates;
        }

        public PlanNode getDecorrelatedNode() {
            return decorrelatedNode;
        }
    }
}

