/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.IntermediateFormExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Optimizer rule that makes implied non-null requirements of a join explicit.
 *
 * <p>For LEFT, RIGHT and INNER joins, the rule looks at the equi-join criteria and the join filter.
 * It collects the variables that are reachable only through top-level {@code AND}s and through
 * calls the session's {@link FeaturesConfig.JoinNotNullInferenceStrategy} considers
 * null-intolerant. A NULL in any such variable keeps the join condition from evaluating to true.
 * For each collected variable that belongs to the candidate side(s) of the join, the rule appends
 * a {@code NOT(x IS NULL)} conjunct to the join filter. Other join types are left untouched.
 */
public class AddNotNullFiltersToJoinNode
implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();
    // One logger per class rather than one per rule instance.
    private static final Logger log = Logger.get(AddNotNullFiltersToJoinNode.class);

    private final FunctionAndTypeManager functionAndTypeManager;
    private final FunctionResolution functionResolution;

    public AddNotNullFiltersToJoinNode(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    /**
     * The rule is active unless the session's not-null inference strategy is {@code NONE}.
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.getNotNullInferenceStrategy(session) != FeaturesConfig.JoinNotNullInferenceStrategy.NONE;
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code joinNode} with additional {@code NOT(x IS NULL)} conjuncts in its filter.
     *
     * <p>Candidate variables are the right side's outputs for a LEFT join, the left side's outputs
     * for a RIGHT join, and both sides' outputs (left first) for an INNER join. Any other join type
     * yields an empty result. An empty result is also returned when nothing is inferred, or when
     * the existing filter already carries a {@code NOT(x IS NULL)} conjunct for every inferred
     * variable.
     *
     * @return the rewritten join, or {@link Rule.Result#empty()} when no rewrite applies
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        FeaturesConfig.JoinNotNullInferenceStrategy notNullInferenceStrategy = SystemSessionProperties.getNotNullInferenceStrategy(context.getSession());

        List<VariableReferenceExpression> candidates;
        switch (joinNode.getType()) {
            case LEFT:
                // NOTE(review): only the non-preserved side is a candidate for outer joins.
                candidates = joinNode.getRight().getOutputVariables();
                break;
            case RIGHT:
                candidates = joinNode.getLeft().getOutputVariables();
                break;
            case INNER:
                candidates = ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(joinNode.getLeft().getOutputVariables())
                        .addAll(joinNode.getRight().getOutputVariables())
                        .build();
                break;
            default:
                return Rule.Result.empty();
        }

        Collection<VariableReferenceExpression> inferredNotNullVariables =
                extractNotNullVariables(joinNode.getCriteria(), joinNode.getFilter(), candidates, notNullInferenceStrategy);
        if (inferredNotNullVariables.isEmpty()) {
            return Rule.Result.empty();
        }

        Set<VariableReferenceExpression> existingNotNullVariables = getExistingNotNullVariables(joinNode.getFilter());
        log.debug("NotNull filters :: Existing : %s, Inferred :%s", existingNotNullVariables, inferredNotNullVariables);
        if (existingNotNullVariables.containsAll(inferredNotNullVariables)) {
            // Every inferred variable is already guarded by the current filter.
            return Rule.Result.empty();
        }

        RowExpression updatedJoinFilter = LogicalRowExpressions.and(
                joinNode.getFilter().orElse(LogicalRowExpressions.TRUE_CONSTANT),
                buildNotNullRowExpression(inferredNotNullVariables));
        return Rule.Result.ofPlanNode(new JoinNode(
                joinNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                joinNode.getType(),
                joinNode.getLeft(),
                joinNode.getRight(),
                joinNode.getCriteria(),
                joinNode.getOutputVariables(),
                Optional.of(updatedJoinFilter),
                joinNode.getLeftHashVariable(),
                joinNode.getRightHashVariable(),
                joinNode.getDistributionType(),
                joinNode.getDynamicFilters()));
    }

    /**
     * Returns the candidate variables that the join condition requires to be non-null.
     *
     * <p>Both sides of every equi-join clause are AND-ed together with the optional join filter
     * into one expression, which is then scanned by {@link ExtractInferredNotNullVariablesVisitor}.
     * An equality comparison involving NULL is never true, so equi-join keys always qualify.
     *
     * @return a view over {@code candidates} (in their order) restricted to the inferred variables
     */
    private Collection<VariableReferenceExpression> extractNotNullVariables(List<EquiJoinClause> joinCriteria, Optional<RowExpression> joinFilter, List<VariableReferenceExpression> candidates, FeaturesConfig.JoinNotNullInferenceStrategy notNullInferenceStrategy) {
        // Declared as RowExpression: and(...) returns a RowExpression, not a ConstantExpression.
        RowExpression combinedFilter = LogicalRowExpressions.TRUE_CONSTANT;
        for (EquiJoinClause criteria : joinCriteria) {
            combinedFilter = LogicalRowExpressions.and(combinedFilter, criteria.getLeft());
            combinedFilter = LogicalRowExpressions.and(combinedFilter, criteria.getRight());
        }
        combinedFilter = LogicalRowExpressions.and(combinedFilter, joinFilter.orElse(LogicalRowExpressions.TRUE_CONSTANT));
        return Sets.intersection(ImmutableSet.copyOf(candidates), inferNotNullVariables(combinedFilter, notNullInferenceStrategy));
    }

    /**
     * Collects the variables already guarded by a {@code NOT(x IS NULL)} conjunct in the filter.
     *
     * <p>Only conjuncts reachable through {@code AND} special forms are inspected. Calls,
     * intermediate-form expressions and other special forms are not descended into.
     *
     * @param joinFilter the join's current filter, possibly absent
     * @return the guarded variables; empty when there is no filter
     */
    @VisibleForTesting
    Set<VariableReferenceExpression> getExistingNotNullVariables(Optional<RowExpression> joinFilter) {
        if (!joinFilter.isPresent()) {
            return ImmutableSet.of();
        }
        ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
        DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>> isNotNullExtractingVisitor =
                new DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>>() {
                    @Override
                    public Void visitCall(CallExpression call, ImmutableSet.Builder<VariableReferenceExpression> context) {
                        Optional<VariableReferenceExpression> guarded = getIsNotNullOperand(call);
                        if (guarded.isPresent()) {
                            context.add(guarded.get());
                        }
                        return null;
                    }

                    @Override
                    public Void visitIntermediateFormExpression(IntermediateFormExpression expression, ImmutableSet.Builder<VariableReferenceExpression> context) {
                        return null;
                    }

                    @Override
                    public Void visitSpecialForm(SpecialFormExpression specialForm, ImmutableSet.Builder<VariableReferenceExpression> context) {
                        if (specialForm.getForm() == SpecialFormExpression.Form.AND) {
                            return super.visitSpecialForm(specialForm, context);
                        }
                        return null;
                    }
                };
        joinFilter.get().accept(isNotNullExtractingVisitor, builder);
        return builder.build();
    }

    /**
     * Matches {@code call} against the shape {@code NOT(IS_NULL(variable))}.
     *
     * @return the variable operand when the shape matches, otherwise empty
     */
    private Optional<VariableReferenceExpression> getIsNotNullOperand(CallExpression call) {
        if (!functionResolution.isNotFunction(call.getFunctionHandle()) || call.getArguments().size() != 1) {
            return Optional.empty();
        }
        RowExpression argument = call.getArguments().get(0);
        if (!(argument instanceof SpecialFormExpression)) {
            return Optional.empty();
        }
        SpecialFormExpression isNull = (SpecialFormExpression) argument;
        if (isNull.getForm() != SpecialFormExpression.Form.IS_NULL || isNull.getArguments().size() != 1) {
            return Optional.empty();
        }
        RowExpression operand = isNull.getArguments().get(0);
        if (!(operand instanceof VariableReferenceExpression)) {
            return Optional.empty();
        }
        return Optional.of((VariableReferenceExpression) operand);
    }

    /**
     * Runs {@link ExtractInferredNotNullVariablesVisitor} over {@code expression} and returns
     * every variable it reports, in first-visit order.
     */
    private ImmutableSet<VariableReferenceExpression> inferNotNullVariables(RowExpression expression, FeaturesConfig.JoinNotNullInferenceStrategy notNullInferenceStrategy) {
        ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
        expression.accept(new ExtractInferredNotNullVariablesVisitor(functionAndTypeManager, notNullInferenceStrategy), builder);
        return builder.build();
    }

    /**
     * Builds the conjunction {@code NOT(v1 IS NULL) AND NOT(v2 IS NULL) ...}, following the
     * iteration order of {@code expressions}.
     */
    private RowExpression buildNotNullRowExpression(Collection<VariableReferenceExpression> expressions) {
        List<RowExpression> isNotNullExpressions = expressions.stream()
                .map(this::buildIsNotNull)
                .collect(ImmutableList.toImmutableList());
        return LogicalRowExpressions.and(isNotNullExpressions);
    }

    /**
     * Builds {@code NOT(variable IS NULL)}, reusing the variable's source location.
     */
    private RowExpression buildIsNotNull(VariableReferenceExpression variable) {
        RowExpression isNull = new SpecialFormExpression(variable.getSourceLocation(), SpecialFormExpression.Form.IS_NULL, BooleanType.BOOLEAN, variable);
        return new CallExpression(variable.getSourceLocation(), "not", functionResolution.notFunction(), BooleanType.BOOLEAN, Collections.singletonList(isNull));
    }

    /**
     * Visitor that collects variables a predicate needs to be non-null in order to evaluate to true.
     *
     * <p>It descends only through {@code AND} special forms and through calls the configured
     * strategy treats as null-intolerant. Every variable reference reached that way is added to
     * the builder passed as context.
     */
    @VisibleForTesting
    public static class ExtractInferredNotNullVariablesVisitor
    extends DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>> {
        private final FunctionAndTypeManager functionAndTypeManager;
        private final FeaturesConfig.JoinNotNullInferenceStrategy notNullInferenceStrategy;

        public ExtractInferredNotNullVariablesVisitor(FunctionAndTypeManager functionAndTypeManager, FeaturesConfig.JoinNotNullInferenceStrategy notNullInferenceStrategy) {
            this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.notNullInferenceStrategy = Objects.requireNonNull(notNullInferenceStrategy, "notNullInferenceStrategy is null");
        }

        @Override
        public Void visitCall(CallExpression call, ImmutableSet.Builder<VariableReferenceExpression> context) {
            FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle());
            if (!isNullIntolerant(functionMetadata)) {
                // A NULL argument might still produce a non-null result, so nothing can be inferred below here.
                return null;
            }
            return super.visitCall(call, context);
        }

        /**
         * Decides whether a call is treated as returning NULL for any NULL argument.
         *
         * <p>{@code INFER_FROM_STANDARD_OPERATORS} trusts only operator functions that are not
         * called on null input. {@code USE_FUNCTION_METADATA} trusts any function whose metadata
         * says it is not called on null input. Any other strategy trusts nothing.
         */
        private boolean isNullIntolerant(FunctionMetadata functionMetadata) {
            switch (notNullInferenceStrategy) {
                case INFER_FROM_STANDARD_OPERATORS: {
                    Optional<OperatorType> operatorType = functionMetadata.getOperatorType();
                    return operatorType.isPresent() && !operatorType.get().isCalledOnNullInput();
                }
                case USE_FUNCTION_METADATA:
                    return !functionMetadata.isCalledOnNullInput();
                default:
                    return false;
            }
        }

        @Override
        public Void visitSpecialForm(SpecialFormExpression specialForm, ImmutableSet.Builder<VariableReferenceExpression> context) {
            if (specialForm.getForm() == SpecialFormExpression.Form.AND) {
                return super.visitSpecialForm(specialForm, context);
            }
            // OR, IS_NULL, COALESCE, etc. can be satisfied with NULL operands; stop here.
            return null;
        }

        @Override
        public Void visitIntermediateFormExpression(IntermediateFormExpression expression, ImmutableSet.Builder<VariableReferenceExpression> context) {
            return null;
        }

        @Override
        public Void visitVariableReference(VariableReferenceExpression variableReferenceExpression, ImmutableSet.Builder<VariableReferenceExpression> context) {
            context.add(variableReferenceExpression);
            return super.visitVariableReference(variableReferenceExpression, context);
        }
    }
}

