package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Expression rewrite rule that folds equality disjuncts on the same named expression into an
 * {@link InPredicate}.
 *
 * <p>Example: {@code a = 1 OR a = 2 OR b = 3} becomes {@code a IN (1, 2) OR b = 3}. A named
 * expression is folded only when it is compared against at least
 * {@code REWRITE_OR_TO_IN_PREDICATE_THRESHOLD} distinct literals. Every disjunct that is not
 * folded is kept and rewritten recursively.
 */
public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
        implements ExpressionRewriteRule<ExpressionRewriteContext> {

    public static final OrToIn INSTANCE = new OrToIn();

    /** Minimum number of distinct literals on one named expression before an IN is produced. */
    private static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;

    @Override
    public Expression rewrite(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
        return expression.accept(this, expressionRewriteContext);
    }

    /**
     * Rewrites a (possibly nested) OR by grouping {@code slot = literal} disjuncts into IN predicates.
     *
     * @param or the disjunction to rewrite
     * @param expressionRewriteContext rewrite context, propagated to recursive rewrites
     * @return the OR of the generated IN predicates followed by the remaining (recursively
     *         rewritten) disjuncts, in their original order
     */
    @Override
    public Expression visitOr(Or or, ExpressionRewriteContext expressionRewriteContext) {
        List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);

        // Insertion-ordered collections keep the output deterministic: IN predicates follow the
        // first occurrence of their slot, and IN-list literals follow their disjunct order.
        Map<NamedExpression, Set<Literal>> slotToLiterals = new LinkedHashMap<>();
        for (Expression disjunct : disjuncts) {
            if (disjunct instanceof EqualTo) {
                addSlotToLiteralMap((EqualTo) disjunct, slotToLiterals);
            }
        }

        List<Expression> rewritten = new ArrayList<>();
        for (Map.Entry<NamedExpression, Set<Literal>> entry : slotToLiterals.entrySet()) {
            Set<Literal> literals = entry.getValue();
            if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
                rewritten.add(new InPredicate(entry.getKey(), ImmutableList.copyOf(literals)));
            }
        }
        // Keep (and recursively rewrite) every disjunct that was not absorbed into an IN above.
        for (Expression disjunct : disjuncts) {
            if (!ableToConvertToIn(disjunct, slotToLiterals)) {
                rewritten.add(disjunct.accept(this, expressionRewriteContext));
            }
        }
        return ExpressionUtils.or(rewritten);
    }

    /**
     * Records {@code slot -> literal} in {@code map} when {@code equalTo} compares a named
     * expression with a literal (in either operand order); otherwise does nothing.
     */
    private void addSlotToLiteralMap(EqualTo equalTo, Map<NamedExpression, Set<Literal>> map) {
        Expression left = equalTo.left();
        Expression right = equalTo.right();
        if (left instanceof NamedExpression && right instanceof Literal) {
            addSlotToLiteral((NamedExpression) left, (Literal) right, map);
        }
        if (right instanceof NamedExpression && left instanceof Literal) {
            addSlotToLiteral((NamedExpression) right, (Literal) left, map);
        }
    }

    /**
     * Returns whether {@code expression} is a {@code slot = literal} equality whose slot has
     * enough distinct literals in {@code map} to have been folded into an IN predicate.
     */
    private boolean ableToConvertToIn(Expression expression, Map<NamedExpression, Set<Literal>> map) {
        if (!(expression instanceof EqualTo)) {
            return false;
        }
        EqualTo equalTo = (EqualTo) expression;
        Expression left = equalTo.left();
        Expression right = equalTo.right();
        NamedExpression namedExpression = null;
        if (left instanceof NamedExpression && right instanceof Literal) {
            namedExpression = (NamedExpression) left;
        }
        if (right instanceof NamedExpression && left instanceof Literal) {
            namedExpression = (NamedExpression) right;
        }
        return namedExpression != null
                && findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, map)
                >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD;
    }

    /**
     * Adds {@code literal} to the insertion-ordered literal set of {@code namedExpression},
     * creating the set on first use.
     */
    public void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
            Map<NamedExpression, Set<Literal>> map) {
        map.computeIfAbsent(namedExpression, key -> new LinkedHashSet<>()).add(literal);
    }

    /**
     * Returns the number of distinct literals recorded for {@code namedExpression}, or 0 if none.
     */
    public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression,
            Map<NamedExpression, Set<Literal>> map) {
        return map.getOrDefault(namedExpression, Collections.emptySet()).size();
    }
}
