package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

/* loaded from: input_file:org/apache/doris/nereids/rules/rewrite/PullUpPredicates.class */
/**
 * Plan visitor that computes, for a plan node, a set of predicates known to hold on the rows the node
 * produces.
 *
 * <p>Each dedicated handler starts from its child predicates and adds the predicates the node itself
 * contributes. It then expands the set with whatever {@link PredicatePropagation#infer} derives. Finally
 * it keeps only predicates whose input slots are all part of the node's output.
 *
 * <p>Results of the dedicated handlers are memoized per plan instance, compared by identity rather than
 * {@code equals}. A node reached more than once is therefore computed only once.
 *
 * <p>Not thread-safe: the memo table is an unsynchronized {@link IdentityHashMap}.
 */
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {
    final PredicatePropagation propagation = new PredicatePropagation();

    /** Memoized results keyed by plan identity; populated by {@link #cacheOrElse}. */
    final Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();

    /**
     * Handles plan nodes that have no dedicated handler.
     *
     * <p>A unary node passes its child's predicates through unchanged. Any other node (leaf or
     * multi-child) contributes none. This result is neither memoized nor filtered against the node's
     * output.
     *
     * <p>NOTE(review): unary nodes that drop or null out columns receive predicates that may mention
     * slots they no longer output.
     */
    @Override
    public ImmutableSet<Expression> visit(Plan plan, Void context) {
        if (plan.arity() != 1) {
            return ImmutableSet.of();
        }
        return plan.child(0).accept(this, context);
    }

    /** Returns the filter's own conjuncts plus its child's predicates, expanded and restricted to its output. */
    @Override
    public ImmutableSet<Expression> visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
        return cacheOrElse(filter, () -> {
            Collection<Expression> predicates = Lists.newArrayList(filter.getConjuncts());
            predicates.addAll(filter.child().accept(this, context));
            return getAvailableExpressions(predicates, filter);
        });
    }

    /**
     * Returns the union of the left and right children's predicates, expanded and restricted to the
     * join's output.
     *
     * <p>NOTE(review): the join type and join conditions are not consulted here. Predicates from the
     * null-supplying side of an outer join do not necessarily hold on the join output. Confirm that
     * callers account for this.
     */
    @Override
    public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
        return cacheOrElse(join, () -> {
            Set<Expression> predicates = Sets.newHashSet();
            predicates.addAll(join.left().accept(this, context));
            predicates.addAll(join.right().accept(this, context));
            return getAvailableExpressions(predicates, join);
        });
    }

    /**
     * Returns the child's predicates, plus copies of them rewritten in terms of the project's aliases.
     *
     * <p>For every alias/producer pair, each child predicate is rewritten so that sub-expressions equal to
     * the producer are replaced by the alias slot. This lets predicates survive when the original
     * expression is only exposed under an alias.
     */
    @Override
    public ImmutableSet<Expression> visitLogicalProject(LogicalProject<? extends Plan> project, Void context) {
        return cacheOrElse(project, () -> {
            ImmutableSet<Expression> childPredicates = project.child().accept(this, context);
            Set<Expression> predicates = Sets.newHashSet(childPredicates);
            project.getAliasToProducer().forEach((aliasSlot, producer) -> {
                for (Expression predicate : childPredicates) {
                    predicates.add(predicate.rewriteDownShortCircuit(
                            node -> node.equals(producer) ? aliasSlot : node));
                }
            });
            return getAvailableExpressions(predicates, project);
        });
    }

    /**
     * Returns the child's predicates with aggregate expressions replaced by the output slots that expose
     * them, expanded and restricted to the aggregate's output.
     */
    @Override
    public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
        return cacheOrElse(aggregate, () -> {
            ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
            // Map each aggregate-bearing output (unwrapped from its Alias) to the slot that exposes it.
            // Several outputs may compute the same aggregate (e.g. `sum(a) AS x, sum(a) AS y`). Those
            // slots carry the same value, so keep the first one instead of letting Collectors.toMap
            // throw IllegalStateException on the duplicate key.
            Map<Expression, Expression> aggregateToSlot = aggregate.getOutputExpressions().stream()
                    .filter(this::hasAgg)
                    .collect(Collectors.toMap(
                            output -> output instanceof Alias ? ((Alias) output).child() : output,
                            output -> output.toSlot(),
                            (first, second) -> first));
            Expression conjunction = ExpressionUtils.and(Lists.newArrayList(childPredicates));
            Expression rewritten = ExpressionUtils.replace(conjunction, aggregateToSlot);
            return getAvailableExpressions(ExpressionUtils.extractConjunction(rewritten), aggregate);
        });
    }

    /**
     * Returns the memoized predicates for {@code plan}, computing and storing them with {@code supplier}
     * on first request.
     *
     * <p>This uses an explicit get/put rather than {@code computeIfAbsent}. The supplier recursively
     * visits children, which inserts into this same map while the outer computation is still running.
     */
    private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Expression>> supplier) {
        ImmutableSet<Expression> cached = cache.get(plan);
        if (cached != null) {
            return cached;
        }
        ImmutableSet<Expression> computed = supplier.get();
        cache.put(plan, computed);
        return computed;
    }

    /**
     * Expands {@code predicates} with whatever {@link PredicatePropagation#infer} derives from them.
     * Keeps only predicates whose input slots are all contained in {@code plan}'s output set.
     */
    private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
        Set<Expression> candidates = Sets.newHashSet(predicates);
        candidates.addAll(propagation.infer(candidates));
        // Loop-invariant: fetch the output set once instead of once per candidate.
        Set<? extends Expression> outputSet = plan.getOutputSet();
        return candidates.stream()
                .filter(predicate -> outputSet.containsAll(predicate.getInputSlots()))
                .collect(ImmutableSet.toImmutableSet());
    }

    /** Returns whether {@code expression} contains an {@link AggregateFunction}, as matched by {@code anyMatch}. */
    private boolean hasAgg(Expression expression) {
        return expression.anyMatch(AggregateFunction.class::isInstance);
    }

    /**
     * Equivalent to {@link #visitLogicalFilter}.
     *
     * @deprecated use {@link #visitLogicalFilter(LogicalFilter, Void)}
     */
    @Deprecated
    public ImmutableSet<Expression> visitLogicalFilter2(LogicalFilter<? extends Plan> logicalFilter, Void context) {
        return visitLogicalFilter(logicalFilter, context);
    }

    /**
     * Equivalent to {@link #visitLogicalJoin}.
     *
     * @deprecated use {@link #visitLogicalJoin(LogicalJoin, Void)}
     */
    @Deprecated
    public ImmutableSet<Expression> visitLogicalJoin2(
            LogicalJoin<? extends Plan, ? extends Plan> logicalJoin, Void context) {
        return visitLogicalJoin(logicalJoin, context);
    }

    /**
     * Equivalent to {@link #visitLogicalProject}.
     *
     * @deprecated use {@link #visitLogicalProject(LogicalProject, Void)}
     */
    @Deprecated
    public ImmutableSet<Expression> visitLogicalProject2(LogicalProject<? extends Plan> logicalProject, Void context) {
        return visitLogicalProject(logicalProject, context);
    }

    /**
     * Equivalent to {@link #visitLogicalAggregate}.
     *
     * @deprecated use {@link #visitLogicalAggregate(LogicalAggregate, Void)}
     */
    @Deprecated
    public ImmutableSet<Expression> visitLogicalAggregate2(
            LogicalAggregate<? extends Plan> logicalAggregate, Void context) {
        return visitLogicalAggregate(logicalAggregate, context);
    }
}
