package org.apache.doris.nereids.rules.expression.rules;

import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;

/**
 * Expression rewrite rule that folds repeated {@code array_contains} probes inside a disjunction into a
 * single {@code arrays_overlap} call.
 *
 * <p>For every {@link Or}, the disjuncts of the form {@code array_contains(arr, <literal>)} are grouped by
 * their first argument {@code arr}. When a group holds more than {@link #REWRITE_PREDICATE_THRESHOLD}
 * distinct literals, every disjunct of that group is replaced by one
 * {@code arrays_overlap(arr, [literal, ...])}. Smaller groups are kept unchanged. All other disjuncts are
 * rewritten recursively with this rule.
 *
 * <p>Example: {@code array_contains(a, 1) OR array_contains(a, 2) OR array_contains(a, 3)} becomes
 * {@code arrays_overlap(a, [1, 2, 3])}.
 */
public class ArrayContainToArrayOverlap extends DefaultExpressionRewriter<ExpressionRewriteContext>
        implements ExpressionRewriteRule<ExpressionRewriteContext> {

    /** Shared instance; the rule keeps no per-invocation state in fields. */
    public static final ArrayContainToArrayOverlap INSTANCE = new ArrayContainToArrayOverlap();

    /** A group of probes on the same array is converted only when it has strictly more distinct literals. */
    private static final int REWRITE_PREDICATE_THRESHOLD = 2;

    /**
     * Applies this rule to {@code expression} by visiting it with this rewriter.
     *
     * @param expression the expression to rewrite
     * @param expressionRewriteContext the rewrite context, forwarded to the visitor
     * @return the rewritten expression
     */
    @Override
    public Expression rewrite(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
        return expression.accept(this, expressionRewriteContext);
    }

    /**
     * Rewrites a disjunction, collapsing large groups of literal {@code array_contains} probes on the same
     * array expression into {@code arrays_overlap}.
     *
     * <p>The resulting disjunct order is: generated {@code arrays_overlap} calls, then the
     * {@code array_contains} disjuncts that were not converted, then the recursively rewritten remaining
     * disjuncts.
     *
     * @param or the disjunction to rewrite
     * @param expressionRewriteContext the rewrite context, passed on when recursing into other disjuncts
     * @return the rebuilt disjunction
     */
    @Override
    public Expression visitOr(Or or, ExpressionRewriteContext expressionRewriteContext) {
        Map<Boolean, List<Expression>> partitioned = ExpressionUtils.extractDisjunction(or).stream()
                .collect(Collectors.partitioningBy(this::isValidArrayContains));
        List<Expression> containsDisjuncts = partitioned.get(true);
        List<Expression> otherDisjuncts = partitioned.get(false);

        // Group the probed literals by the array expression they are looked up in. Using a set means
        // duplicate probes of the same literal count only once toward the threshold.
        Map<Expression, Set<Literal>> literalsByArray = new HashMap<>();
        for (Expression contains : containsDisjuncts) {
            literalsByArray.computeIfAbsent(contains.child(0), key -> new HashSet<>())
                    .add((Literal) contains.child(1));
        }

        ImmutableList.Builder<Expression> newDisjuncts = ImmutableList.builder();
        literalsByArray.forEach((array, literals) -> {
            if (literals.size() > REWRITE_PREDICATE_THRESHOLD) {
                newDisjuncts.add(new ArraysOverlap(array, new ArrayLiteral(ImmutableList.copyOf(literals))));
            }
        });
        // Keep only the array_contains disjuncts whose group was not folded into arrays_overlap above.
        for (Expression contains : containsDisjuncts) {
            if (!canConvertToArrayOverlap(contains, literalsByArray)) {
                newDisjuncts.add(contains);
            }
        }
        // Recurse into the remaining disjuncts with the caller's context (previously a null context was passed).
        for (Expression other : otherDisjuncts) {
            newDisjuncts.add(other.accept(this, expressionRewriteContext));
        }
        return ExpressionUtils.or(newDisjuncts.build());
    }

    /**
     * Returns whether {@code expression} is an {@code array_contains} whose second argument is a literal,
     * which makes it a candidate for folding.
     */
    private boolean isValidArrayContains(Expression expression) {
        return expression instanceof ArrayContains && expression.child(1) instanceof Literal;
    }

    /**
     * Returns whether {@code expression} is an {@code array_contains} whose array argument belongs to a group
     * large enough to have been converted into {@code arrays_overlap}.
     *
     * @param expression a disjunct to check
     * @param literalsByArray distinct probed literals, keyed by array expression
     */
    private boolean canConvertToArrayOverlap(Expression expression, Map<Expression, Set<Literal>> literalsByArray) {
        if (!(expression instanceof ArrayContains)) {
            return false;
        }
        Set<Literal> literals = literalsByArray.get(expression.child(0));
        return literals != null && literals.size() > REWRITE_PREDICATE_THRESHOLD;
    }
}
