package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.nereids.util.TypeUtils;

/* loaded from: input_file:org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.class */
/**
 * Expression rewrite rule that moves constant arithmetic off the left operand of a comparison.
 *
 * <p>Handled predicates are {@code >}, {@code >=}, {@code =}, {@code <} and {@code <=}. The rewrites
 * applied to the left operand, in order, are:
 * <ul>
 *   <li>{@code (x + c) op y  ->  x op (y - c)} when {@code c} is constant;</li>
 *   <li>{@code (x - c) op y  ->  x op (y + c)} when {@code c} is constant;</li>
 *   <li>{@code (x / l) op y  ->  x op (y * l)} when {@code l} is a literal, with the operands swapped
 *       when {@code l} is negative so that inequalities keep the correct direction.</li>
 * </ul>
 */
public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteRule {
    public static final SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule();

    /**
     * Returns any expression that is not one of the handled comparison predicates unchanged.
     *
     * <p>This override does not descend into children. NOTE(review): presumably the rewrite
     * framework drives traversal of sub-expressions itself — confirm.
     */
    @Override
    public Expression visit(Expression expression, ExpressionRewriteContext context) {
        return expression;
    }

    /**
     * Moves a constant addend/subtrahend and then a literal divisor from the left operand of
     * {@code comparisonPredicate} to its right operand.
     *
     * @param comparisonPredicate the predicate to simplify
     * @return {@code comparisonPredicate} itself when nothing was rewritten; otherwise the result of
     *     {@link TypeCoercionUtils#processComparisonPredicate} for the new operands (presumably the
     *     same predicate kind rebuilt with type coercion applied — TODO confirm)
     */
    private Expression process(ComparisonPredicate comparisonPredicate) {
        Expression left = comparisonPredicate.left();
        Expression right = comparisonPredicate.right();

        // (x + c) op y  ->  x op (y - c);   (x - c) op y  ->  x op (y + c)
        if (TypeUtils.isAddOrSubtract(left)) {
            Expression offset = left.child(1);
            if (offset.isConstant()) {
                if (TypeUtils.isAdd(left)) {
                    right = new Subtract(right, offset);
                }
                if (TypeUtils.isSubtract(left)) {
                    right = new Add(right, offset);
                }
                left = left.child(0);
            }
        }

        // (x / l) op y  ->  x op (y * l). This runs after the step above, so (x / l + c) op y is
        // also reduced.
        // NOTE(review): a literal divisor of zero is not excluded here; confirm that this rewrite
        // preserves the engine's division-by-zero semantics.
        if (TypeUtils.isDivide(left)) {
            Expression divisor = left.child(1);
            if (divisor.isLiteral()) {
                right = new Multiply(right, divisor);
                left = left.child(0);
                // Multiplying both sides by a negative number reverses an inequality. Swapping the
                // operands while keeping the same operator achieves that: x < y * l  <=>  y * l > x.
                // For '=' the swap is harmless.
                // NOTE(review): negativity is detected from the literal's string form, which assumes
                // negative literals render with a leading '-' — confirm.
                if (divisor.toString().startsWith("-")) {
                    Expression previousLeft = left;
                    left = right;
                    right = previousLeft;
                }
            }
        }

        if (left == comparisonPredicate.left() && right == comparisonPredicate.right()) {
            // Nothing was rewritten; keep the original node.
            return comparisonPredicate;
        }
        return TypeCoercionUtils.processComparisonPredicate(comparisonPredicate, left, right);
    }

    @Override
    public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
        return process(greaterThan);
    }

    @Override
    public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
        return process(greaterThanEqual);
    }

    @Override
    public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
        return process(equalTo);
    }

    @Override
    public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
        return process(lessThan);
    }

    @Override
    public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
        return process(lessThanEqual);
    }
}
