package org.apache.doris.nereids.rules.expression.rules;

import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import org.apache.doris.analysis.DecimalLiteral;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.types.DecimalV3Type;

/* loaded from: input_file:org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.class */
/**
 * Expression rewrite rule that removes an unnecessary DECIMALV3 cast from one side of a comparison.
 *
 * <p>It matches comparisons of the shape {@code cast(x AS decimalv3) <op> decimalv3Literal}, where
 * {@code x} itself has a DECIMALV3 type. When the literal (with trailing zeros stripped) fits into
 * {@code x}'s own precision and scale, the predicate is rewritten to {@code x <op> literal'}, where
 * {@code literal'} carries the same value re-typed as {@code x}'s DECIMALV3 type. Otherwise the
 * predicate keeps its cast, but any rewrites applied to its operands are preserved.
 */
public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
    /** Shared instance of this rule. */
    public static final SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();

    /**
     * Rewrites both operands of the comparison, then tries to drop a DECIMALV3-to-DECIMALV3 cast on
     * the left operand when the right operand is a DECIMALV3 literal.
     *
     * @param comparisonPredicate the comparison being visited
     * @param expressionRewriteContext the rewrite context forwarded to the operand rewrites
     * @return the simplified comparison, a copy holding the rewritten operands, or
     *     {@code comparisonPredicate} itself when nothing changed
     */
    @Override
    public Expression visitComparisonPredicate(ComparisonPredicate comparisonPredicate,
            ExpressionRewriteContext expressionRewriteContext) {
        Expression left = rewrite(comparisonPredicate.left(), expressionRewriteContext);
        Expression right = rewrite(comparisonPredicate.right(), expressionRewriteContext);

        if (isDecimalV3CastOfDecimalV3(left) && right instanceof DecimalV3Literal) {
            return doProcess(comparisonPredicate, (Cast) left, (DecimalV3Literal) right);
        }
        return withChildrenIfChanged(comparisonPredicate, left, right);
    }

    /**
     * Returns whether {@code expression} is a cast producing DECIMALV3 whose input is also DECIMALV3.
     */
    private static boolean isDecimalV3CastOfDecimalV3(Expression expression) {
        return expression.getDataType() instanceof DecimalV3Type
                && expression instanceof Cast
                && ((Cast) expression).child().getDataType() instanceof DecimalV3Type;
    }

    /**
     * Replaces the cast with its child and re-types the literal when the literal fits the child's type.
     *
     * @param comparisonPredicate the original comparison (used as the template for the result)
     * @param cast the already-rewritten left operand; its child must be DECIMALV3
     * @param decimalV3Literal the already-rewritten right operand
     * @return {@code castChild <op> retypedLiteral} when the literal fits, otherwise the comparison
     *     over the rewritten operands (or the original comparison if they are unchanged)
     */
    private Expression doProcess(ComparisonPredicate comparisonPredicate, Cast cast,
            DecimalV3Literal decimalV3Literal) {
        // Strip trailing zeros so redundant fractional zeros (e.g. 1.500 -> 1.5) do not count against
        // the child's scale.
        BigDecimal literalValue = decimalV3Literal.getValue().stripTrailingZeros();
        int literalScale = DecimalLiteral.getBigDecimalScale(literalValue);
        int literalPrecision = DecimalLiteral.getBigDecimalPrecision(literalValue);

        Expression castChild = cast.child();
        Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
        DecimalV3Type childType = (DecimalV3Type) castChild.getDataType();

        // The literal fits when it needs no more fractional digits and no more integer digits than
        // the child's type provides.
        boolean fitsChildType = literalScale <= childType.getScale()
                && literalPrecision - literalScale <= childType.getPrecision() - childType.getScale();
        if (!fitsChildType) {
            // The cast is still required; keep whatever rewrites were applied to the operands
            // instead of discarding them by returning the original predicate.
            return withChildrenIfChanged(comparisonPredicate, cast, decimalV3Literal);
        }

        DecimalV3Literal retypedLiteral = new DecimalV3Literal(
                DecimalV3Type.createDecimalV3Type(childType.getPrecision(), childType.getScale()),
                literalValue);
        return comparisonPredicate.withChildren(castChild, retypedLiteral);
    }

    /**
     * Returns {@code comparisonPredicate} unchanged when both operands are identical to its current
     * children, otherwise a copy built over {@code left} and {@code right}.
     */
    private static Expression withChildrenIfChanged(ComparisonPredicate comparisonPredicate,
            Expression left, Expression right) {
        if (left == comparisonPredicate.left() && right == comparisonPredicate.right()) {
            return comparisonPredicate;
        }
        return comparisonPredicate.withChildren(left, right);
    }
}
