/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.ExpressionExtractor;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

/**
 * Optimizer rules that move row-field dereferences ({@code DEREFERENCE} special forms such as
 * {@code x.a.b} whose innermost base is a plain variable) further down the plan.
 *
 * <p>Two families of rules are returned by {@link #rules()}:
 * <ul>
 *   <li>{@link ExtractProjectDereferences} subclasses lift dereferences out of a filter predicate
 *   or join filter into a {@link ProjectNode} placed beneath that node;</li>
 *   <li>{@link PushdownDereferencesInProject} subclasses move dereferences computed by a
 *   {@link ProjectNode} below the project's child node.</li>
 * </ul>
 * Every rule is gated by {@link SystemSessionProperties#isPushdownDereferenceEnabled(Session)}.
 */
public class PushDownDereferences {
    private final Metadata metadata;

    public PushDownDereferences(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Returns all dereference push-down rules.
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new ExtractFromFilter(),
                new ExtractFromJoin(),
                new PushDownDereferenceThrough<AssignUniqueId>(AssignUniqueId.class),
                new PushDownDereferenceThrough<WindowNode>(WindowNode.class),
                new PushDownDereferenceThrough<TopNNode>(TopNNode.class),
                new PushDownDereferenceThrough<RowNumberNode>(RowNumberNode.class),
                new PushDownDereferenceThrough<TopNRowNumberNode>(TopNRowNumberNode.class),
                new PushDownDereferenceThrough<SortNode>(SortNode.class),
                new PushDownDereferenceThrough<FilterNode>(FilterNode.class),
                new PushDownDereferenceThrough<LimitNode>(LimitNode.class),
                new PushDownDereferenceThroughProject(),
                new PushDownDereferenceThroughUnnest(),
                new PushDownDereferenceThroughSemiJoin(),
                new PushDownDereferenceThroughJoin());
    }

    /**
     * Rewrites {@code rowExpression}, replacing every sub-expression that is a key of
     * {@code dereferences} with a reference to the variable it maps to.
     */
    private static RowExpression replaceDereferences(RowExpression rowExpression, Map<SpecialFormExpression, VariableReferenceExpression> dereferences) {
        return RowExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(dereferences), rowExpression);
    }

    /**
     * Wraps {@code planNode} in a project that passes through all of its outputs and additionally
     * computes {@code dereferences}. Returns {@code planNode} unchanged when there is nothing to add.
     */
    private static PlanNode createProject(PlanNode planNode, Assignments dereferences, PlanNodeIdAllocator idAllocator) {
        if (dereferences.isEmpty()) {
            return planNode;
        }
        Assignments assignments = Assignments.builder()
                .putAll(AssignmentUtils.identityAssignments(planNode.getOutputVariables()))
                .putAll(dereferences)
                .build();
        return new ProjectNode(idAllocator.getNextId(), planNode, assignments);
    }

    /**
     * Builds a project over {@code source} that passes through all of its outputs and additionally
     * assigns each dereference in {@code dereferences} to its allocated variable. Unlike
     * {@link #createProject}, a project is always created.
     */
    private static ProjectNode projectWithDereferences(PlanNode source, BiMap<SpecialFormExpression, VariableReferenceExpression> dereferences, PlanNodeIdAllocator idAllocator) {
        Assignments assignments = Assignments.builder()
                .putAll(AssignmentUtils.identityAssignments(source.getOutputVariables()))
                .putAll(toAssignmentMap(dereferences))
                .build();
        return new ProjectNode(idAllocator.getNextId(), source, assignments);
    }

    /**
     * Inverts a dereference-to-variable map into a variable-to-expression assignment map.
     * Iteration order follows {@code dereferences.inverse()}.
     */
    private static Map<VariableReferenceExpression, RowExpression> toAssignmentMap(BiMap<SpecialFormExpression, VariableReferenceExpression> dereferences) {
        return dereferences.inverse().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Collects the outermost valid dereferences in {@code expression}. Once a special form is
     * identified as a valid dereference, its arguments are not searched further.
     */
    private static List<SpecialFormExpression> extractDereference(RowExpression expression) {
        ImmutableList.Builder<SpecialFormExpression> builder = ImmutableList.builder();
        RowExpressionVisitor<Void, ImmutableList.Builder<SpecialFormExpression>> collector =
                new DefaultRowExpressionTraversalVisitor<ImmutableList.Builder<SpecialFormExpression>>() {
                    @Override
                    public Void visitSpecialForm(SpecialFormExpression node, ImmutableList.Builder<SpecialFormExpression> context) {
                        if (isValidDereference(node)) {
                            context.add(node);
                        } else {
                            // Not a plain dereference chain: keep searching inside its arguments.
                            node.getArguments().forEach(argument -> argument.accept(this, context));
                        }
                        return null;
                    }
                };
        expression.accept(collector, builder);
        return builder.build();
    }

    /**
     * Allocates a fresh variable for every distinct dereference found in {@code expressions}.
     *
     * <p>Returns an empty map when any dereference's base chain contains another collected
     * dereference (e.g. both {@code x.a} and {@code x.a.b} are used); that case is left untouched.
     * {@code metadata} is currently unused.
     */
    private static Map<SpecialFormExpression, VariableReferenceExpression> getDereferenceSymbolMap(Collection<RowExpression> expressions, Rule.Context context, Metadata metadata) {
        Set<SpecialFormExpression> dereferences = expressions.stream()
                .flatMap(expression -> extractDereference(expression).stream())
                .collect(ImmutableSet.toImmutableSet());
        if (dereferences.stream().anyMatch(expression -> baseExists(expression, dereferences))) {
            return ImmutableMap.of();
        }
        return dereferences.stream()
                .collect(ImmutableMap.toImmutableMap(Function.identity(), expression -> createVariable(expression, context)));
    }

    /**
     * Keeps only the dereferences whose base variable is produced by one of {@code node}'s sources.
     * Dereferences on any other variable (for example a lambda argument inside a lambda body) cannot
     * be evaluated below {@code node} and must stay where they are.
     */
    private static Map<SpecialFormExpression, VariableReferenceExpression> retainDereferencesOverSources(Map<SpecialFormExpression, VariableReferenceExpression> dereferences, PlanNode node) {
        Set<VariableReferenceExpression> sourceVariables = node.getSources().stream()
                .map(PlanNode::getOutputVariables)
                .flatMap(Collection::stream)
                .collect(ImmutableSet.toImmutableSet());
        return dereferences.entrySet().stream()
                .filter(entry -> sourceVariables.contains(getBase(entry.getKey())))
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    private static VariableReferenceExpression createVariable(SpecialFormExpression expression, Rule.Context context) {
        return context.getVariableAllocator().newVariable(expression);
    }

    /**
     * Returns true when some special form on the base chain of {@code expression} (following
     * argument 0 repeatedly) is itself a member of {@code dereferences}.
     */
    private static boolean baseExists(SpecialFormExpression expression, Set<SpecialFormExpression> dereferences) {
        RowExpression base = expression.getArguments().get(0);
        while (base instanceof SpecialFormExpression) {
            if (dereferences.contains(base)) {
                return true;
            }
            base = ((SpecialFormExpression) base).getArguments().get(0);
        }
        return false;
    }

    /**
     * Returns true when {@code dereference} is a chain of {@code DEREFERENCE} special forms whose
     * innermost base (argument 0 at each level) is a {@link VariableReferenceExpression}.
     */
    private static boolean isValidDereference(SpecialFormExpression dereference) {
        RowExpression expression = dereference;
        while (!(expression instanceof VariableReferenceExpression)) {
            if (!(expression instanceof SpecialFormExpression)
                    || ((SpecialFormExpression) expression).getForm() != SpecialFormExpression.Form.DEREFERENCE) {
                return false;
            }
            expression = ((SpecialFormExpression) expression).getArguments().get(0);
        }
        return true;
    }

    /**
     * Returns the single variable referenced by {@code expression}; fails if there is not exactly one.
     */
    private static VariableReferenceExpression getBase(RowExpression expression) {
        return Iterables.getOnlyElement(VariablesExtractor.extractAll(expression));
    }

    /**
     * Routes each dereference to a project above the join side that produces its base variable
     * (left if the left side outputs it, otherwise right). It then rebuilds the join over those
     * sides, outputting every variable of both sides, with the dereferences in the join filter
     * replaced by their variables.
     */
    private static JoinNode pushDereferencesBelowJoin(Rule.Context context, JoinNode joinNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
        Assignments.Builder leftSideDereferences = Assignments.builder();
        Assignments.Builder rightSideDereferences = Assignments.builder();
        for (Map.Entry<VariableReferenceExpression, SpecialFormExpression> entry : expressions.inverse().entrySet()) {
            VariableReferenceExpression baseVariable = getBase(entry.getValue());
            if (joinNode.getLeft().getOutputVariables().contains(baseVariable)) {
                leftSideDereferences.put(entry.getKey(), entry.getValue());
            } else {
                rightSideDereferences.put(entry.getKey(), entry.getValue());
            }
        }
        PlanNode leftNode = createProject(joinNode.getLeft(), leftSideDereferences.build(), context.getIdAllocator());
        PlanNode rightNode = createProject(joinNode.getRight(), rightSideDereferences.build(), context.getIdAllocator());
        List<VariableReferenceExpression> outputVariables = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(leftNode.getOutputVariables())
                .addAll(rightNode.getOutputVariables())
                .build();
        return new JoinNode(
                context.getIdAllocator().getNextId(),
                joinNode.getType(),
                leftNode,
                rightNode,
                joinNode.getCriteria(),
                outputVariables,
                joinNode.getFilter().map(expression -> replaceDereferences(expression, expressions)),
                joinNode.getLeftHashVariable(),
                joinNode.getRightHashVariable(),
                joinNode.getDistributionType(),
                joinNode.getDynamicFilters());
    }

    /**
     * Replaces each special form that is a key of the supplied map with a variable reference
     * carrying the mapped variable's name and the special form's type.
     */
    private static class DereferenceReplacer
    extends RowExpressionRewriter<Void> {
        private final Map<SpecialFormExpression, VariableReferenceExpression> expressions;

        DereferenceReplacer(Map<SpecialFormExpression, VariableReferenceExpression> expressions) {
            this.expressions = Objects.requireNonNull(expressions, "expressions is null");
        }

        @Override
        public RowExpression rewriteSpecialForm(SpecialFormExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
            if (this.expressions.containsKey(node)) {
                return new VariableReferenceExpression(this.expressions.get(node).getName(), node.getType());
            }
            return treeRewriter.defaultRewrite(node, context);
        }
    }

    /**
     * Computes dereferences in a project below an {@link UnnestNode} and replicates the resulting
     * variables through the unnest.
     */
    public class PushDownDereferenceThroughUnnest
    extends PushdownDereferencesInProject<UnnestNode> {
        PushDownDereferenceThroughUnnest() {
            super(Patterns.unnest());
        }

        @Override
        protected Rule.Result pushDownDereferences(Rule.Context context, UnnestNode unnestNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
            ProjectNode source = projectWithDereferences(unnestNode.getSource(), expressions, context.getIdAllocator());
            List<VariableReferenceExpression> replicateVariables = ImmutableList.<VariableReferenceExpression>builder()
                    .addAll(unnestNode.getReplicateVariables())
                    .addAll(expressions.values())
                    .build();
            return Rule.Result.ofPlanNode(new UnnestNode(
                    context.getIdAllocator().getNextId(),
                    source,
                    replicateVariables,
                    unnestNode.getUnnestVariables(),
                    unnestNode.getOrdinalityVariable()));
        }
    }

    /**
     * Adds the dereference assignments to the child {@link ProjectNode}.
     */
    public class PushDownDereferenceThroughProject
    extends PushdownDereferencesInProject<ProjectNode> {
        PushDownDereferenceThroughProject() {
            super(Patterns.project());
        }

        @Override
        protected Rule.Result pushDownDereferences(Rule.Context context, ProjectNode projectNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
            Assignments assignments = Assignments.builder()
                    .putAll(projectNode.getAssignments())
                    .putAll(toAssignmentMap(expressions))
                    .build();
            return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), projectNode.getSource(), assignments));
        }
    }

    /**
     * Routes dereferences to a project above the semi-join's filtering source or source, depending on
     * which one outputs the dereference's base variable.
     */
    public class PushDownDereferenceThroughSemiJoin
    extends PushdownDereferencesInProject<SemiJoinNode> {
        PushDownDereferenceThroughSemiJoin() {
            super(Patterns.semiJoin());
        }

        @Override
        protected Rule.Result pushDownDereferences(Rule.Context context, SemiJoinNode semiJoinNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
            Assignments.Builder filteringSourceDereferences = Assignments.builder();
            Assignments.Builder sourceDereferences = Assignments.builder();
            for (Map.Entry<VariableReferenceExpression, SpecialFormExpression> entry : expressions.inverse().entrySet()) {
                VariableReferenceExpression baseVariable = getBase(entry.getValue());
                if (semiJoinNode.getFilteringSource().getOutputVariables().contains(baseVariable)) {
                    filteringSourceDereferences.put(entry.getKey(), entry.getValue());
                } else {
                    sourceDereferences.put(entry.getKey(), entry.getValue());
                }
            }
            PlanNode filteringSource = createProject(semiJoinNode.getFilteringSource(), filteringSourceDereferences.build(), context.getIdAllocator());
            PlanNode source = createProject(semiJoinNode.getSource(), sourceDereferences.build(), context.getIdAllocator());
            return Rule.Result.ofPlanNode(semiJoinNode.replaceChildren(ImmutableList.of(source, filteringSource)));
        }
    }

    /**
     * Routes dereferences to projects above the join's left or right side.
     */
    public class PushDownDereferenceThroughJoin
    extends PushdownDereferencesInProject<JoinNode> {
        PushDownDereferenceThroughJoin() {
            super(Patterns.join());
        }

        @Override
        protected Rule.Result pushDownDereferences(Rule.Context context, JoinNode joinNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
            return Rule.Result.ofPlanNode(pushDereferencesBelowJoin(context, joinNode, expressions));
        }
    }

    /**
     * Generic single-source variant: inserts a project computing the dereferences between the
     * matched node and its only source.
     */
    public class PushDownDereferenceThrough<N extends PlanNode>
    extends PushdownDereferencesInProject<N> {
        public PushDownDereferenceThrough(Class<N> planNodeClass) {
            super(Pattern.typeOf(planNodeClass));
        }

        @Override
        protected Rule.Result pushDownDereferences(Rule.Context context, N targetNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
            PlanNode source = Iterables.getOnlyElement(targetNode.getSources());
            PlanNode projectNode = projectWithDereferences(source, expressions, context.getIdAllocator());
            return Rule.Result.ofPlanNode(targetNode.replaceChildren(ImmutableList.of(projectNode)));
        }
    }

    /**
     * Base rule for {@code Project -> N}. It finds the dereferences in the project's assignments
     * whose base is produced by one of {@code N}'s sources and lets the subclass compute them below
     * {@code N}. It then replaces those dereferences in the project with the new variables.
     */
    private abstract class PushdownDereferencesInProject<N extends PlanNode>
    implements Rule<ProjectNode> {
        private final Capture<N> targetCapture = Capture.newCapture();
        private final Pattern<N> targetPattern;

        protected PushdownDereferencesInProject(Pattern<N> targetPattern) {
            this.targetPattern = Objects.requireNonNull(targetPattern, "targetPattern is null");
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isPushdownDereferenceEnabled(session);
        }

        @Override
        public Pattern<ProjectNode> getPattern() {
            return Patterns.project().with(Patterns.source().matching(this.targetPattern.capturedAs(this.targetCapture)));
        }

        @Override
        public Rule.Result apply(ProjectNode node, Captures captures, Rule.Context context) {
            N child = captures.get(this.targetCapture);
            Map<SpecialFormExpression, VariableReferenceExpression> allDereferencesInProject =
                    getDereferenceSymbolMap(node.getAssignments().getExpressions(), context, metadata);
            Map<SpecialFormExpression, VariableReferenceExpression> pushdownDereferences =
                    retainDereferencesOverSources(allDereferencesInProject, child);
            if (pushdownDereferences.isEmpty()) {
                return Rule.Result.empty();
            }
            Rule.Result result = this.pushDownDereferences(context, child, HashBiMap.create(pushdownDereferences));
            if (result.isEmpty()) {
                return Rule.Result.empty();
            }
            Assignments.Builder builder = Assignments.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getAssignments().entrySet()) {
                builder.put(entry.getKey(), replaceDereferences(entry.getValue(), pushdownDereferences));
            }
            return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), result.getTransformedPlan().get(), builder.build()));
        }

        /**
         * Rewrites {@code targetNode} so that each key of {@code expressions} is computed below it
         * and exposed as the mapped variable.
         */
        protected abstract Rule.Result pushDownDereferences(Rule.Context context, N targetNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions);
    }

    /**
     * Lifts dereferences out of a join filter into projects above the join's sides.
     */
    class ExtractFromJoin
    extends ExtractProjectDereferences<JoinNode> {
        ExtractFromJoin() {
            super(JoinNode.class);
        }

        @Override
        protected JoinNode rewrite(Rule.Context context, JoinNode joinNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
            return pushDereferencesBelowJoin(context, joinNode, expressions);
        }
    }

    /**
     * Lifts dereferences out of a filter predicate into a project below the filter.
     */
    class ExtractFromFilter
    extends ExtractProjectDereferences<FilterNode> {
        ExtractFromFilter() {
            super(FilterNode.class);
        }

        @Override
        protected FilterNode rewrite(Rule.Context context, FilterNode node, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions) {
            ProjectNode projectNode = projectWithDereferences(node.getSource(), expressions, context.getIdAllocator());
            return new FilterNode(context.getIdAllocator().getNextId(), projectNode, replaceDereferences(node.getPredicate(), expressions));
        }
    }

    /**
     * Base rule that extracts the dereferences found in a node's own expressions, lets the subclass
     * compute them below the node, and restores the node's original outputs with an identity project.
     */
    abstract class ExtractProjectDereferences<N extends PlanNode>
    implements Rule<N> {
        private final Class<N> planNodeClass;

        ExtractProjectDereferences(Class<N> planNodeClass) {
            this.planNodeClass = planNodeClass;
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isPushdownDereferenceEnabled(session);
        }

        @Override
        public Pattern<N> getPattern() {
            return Pattern.typeOf(this.planNodeClass);
        }

        @Override
        public Rule.Result apply(N node, Captures captures, Rule.Context context) {
            // Only dereferences over variables produced by the node's sources can be computed below
            // it; e.g. a dereference on a lambda argument inside a lambda body must stay in place.
            Map<SpecialFormExpression, VariableReferenceExpression> expressions = retainDereferencesOverSources(
                    getDereferenceSymbolMap(ExpressionExtractor.extractExpressionsNonRecursive(node), context, metadata),
                    node);
            if (expressions.isEmpty()) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    this.rewrite(context, node, HashBiMap.create(expressions)),
                    AssignmentUtils.identityAssignments(node.getOutputVariables())));
        }

        /**
         * Returns a replacement for {@code node} in which each key of {@code expressions} is computed
         * below it and referenced through the mapped variable.
         */
        protected abstract N rewrite(Rule.Context context, N node, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions);
    }
}

