package org.apache.doris.nereids.rules.expression.rules;

import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import java.math.RoundingMode;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.DateLikeType;

/**
 * Expression rewrite rule that simplifies a comparison between a cast expression and a literal by
 * removing the cast and converting the literal to the type of the cast's child.
 *
 * <p>The literal is rounded (or the whole predicate folded to a constant) so that the rewritten
 * predicate keeps the meaning of the original one. Three families are handled:
 * <ul>
 *   <li>integer-like value cast to a float-like type, compared with a FLOAT/DOUBLE literal;</li>
 *   <li>DECIMALV3 or integer-like value cast to DECIMALV3, compared with a DECIMALV3 literal;</li>
 *   <li>date-like value cast to a date-like type, compared with a date/datetime literal.</li>
 * </ul>
 * Any other comparison is returned unchanged, apart from the normalization that moves a lone
 * literal operand to the right-hand side.
 */
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
    public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();

    /** How the time part of a DATETIME literal is handled when it is converted to a DATEV2 literal. */
    public enum AdjustType {
        /** Drop the time part (keep the same day). */
        LOWER,
        /** Move to the next day when the literal has a non-zero time part. */
        UPPER,
        /** Drop the time part; used when the predicate is not a range comparison. */
        NONE
    }

    /**
     * Entry point of the rule for comparison predicates.
     *
     * <p>Applies the generic {@code visit} to the predicate first, moves a lone literal operand to
     * the right-hand side, then dispatches on the operand types to the matching simplification.
     */
    @Override
    public Expression visitComparisonPredicate(ComparisonPredicate comparisonPredicate,
            ExpressionRewriteContext expressionRewriteContext) {
        ComparisonPredicate predicate =
                (ComparisonPredicate) visit(comparisonPredicate, expressionRewriteContext);
        if (predicate.left() instanceof Literal && !(predicate.right() instanceof Literal)) {
            // Normalize "literal op expr" into "expr op' literal" so the helpers only inspect the right side.
            predicate = predicate.commute();
        }
        Expression left = predicate.left();
        Expression right = predicate.right();
        if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) {
            return processFloatLikeTypeCoercion(predicate, left, right);
        }
        if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) {
            return processDecimalV3TypeCoercion(predicate, left, right);
        }
        if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) {
            return processDateLikeTypeCoercion(predicate, left, right);
        }
        return predicate;
    }

    /**
     * Returns an expression equivalent to a comparison that can never hold: {@code FALSE} for a
     * non-nullable operand, otherwise {@code operand IS NULL AND NULL}, which evaluates to NULL when
     * the operand is NULL and to FALSE otherwise.
     */
    private static Expression falseOrNull(Expression expression) {
        return expression.nullable()
                ? new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE))
                : BooleanLiteral.of(false);
    }

    /**
     * Simplifies {@code expression op dateTimeV2Literal} where {@code expression} is DATETIMEV2 with
     * a smaller scale than the literal, by rounding the literal to the expression's scale.
     *
     * <p>{@code =} and {@code <=>} keep the rounded literal only when rounding is lossless and
     * otherwise fold to a constant. {@code >} and {@code <=} round down; {@code <} and {@code >=}
     * round up. If the expression's scale is not smaller than the literal's, the predicate is
     * returned unchanged.
     */
    private static Expression processComparisonPredicateDateTimeV2Literal(
            ComparisonPredicate comparisonPredicate, Expression expression, DateTimeV2Literal dateTimeV2Literal) {
        int scale = ((DateTimeV2Type) expression.getDataType()).getScale();
        if (scale >= dateTimeV2Literal.getDataType().getScale()) {
            return comparisonPredicate;
        }
        if (comparisonPredicate instanceof EqualTo || comparisonPredicate instanceof NullSafeEqual) {
            DateTimeV2Literal rounded = dateTimeV2Literal.roundCeiling(scale);
            if (rounded.getMicroSecond() == dateTimeV2Literal.getMicroSecond()) {
                // Rounding did not change the value: the literal is representable at the smaller scale.
                return comparisonPredicate.withChildren(expression, rounded);
            }
            return comparisonPredicate instanceof EqualTo ? falseOrNull(expression) : BooleanLiteral.of(false);
        }
        if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) {
            return comparisonPredicate.withChildren(expression, dateTimeV2Literal.roundFloor(scale));
        }
        if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) {
            return comparisonPredicate.withChildren(expression, dateTimeV2Literal.roundCeiling(scale));
        }
        return comparisonPredicate;
    }

    /**
     * Simplifies {@code cast(child) op literal} for date-like types.
     *
     * <ul>
     *   <li>DATETIME child, DATETIMEV2 literal: compare against a DATETIME literal, but only when the
     *       literal has no fractional seconds (DATETIME cannot represent them).</li>
     *   <li>DATETIMEV2 child, DATETIMEV2 literal: delegate to
     *       {@link #processComparisonPredicateDateTimeV2Literal}.</li>
     *   <li>DATE/DATEV2 child, DATETIME literal: convert the literal to DATEV2, rounding up to the
     *       next day for {@code >=}/{@code <} when it has a time part.</li>
     *   <li>DATE child, DATEV2 literal (possibly produced by the step above): convert to DATE.</li>
     * </ul>
     */
    private Expression processDateLikeTypeCoercion(ComparisonPredicate comparisonPredicate, Expression left,
            Expression right) {
        if (!(left instanceof Cast) || !(right instanceof DateLiteral)) {
            return comparisonPredicate;
        }
        Cast cast = (Cast) left;
        Expression child = cast.child();
        Expression newLeft = left;
        Expression newRight = right;

        if (child.getDataType() instanceof DateTimeType && newRight instanceof DateTimeV2Literal) {
            DateTimeV2Literal literal = (DateTimeV2Literal) newRight;
            if (literal.getMicroSecond() != 0) {
                // Dropping the fraction would change the result of =, < and >=; keep the cast instead.
                return comparisonPredicate;
            }
            newLeft = child;
            newRight = migrateToDateTime(literal);
        }
        if (child.getDataType() instanceof DateTimeV2Type && newRight instanceof DateTimeV2Literal) {
            return processComparisonPredicateDateTimeV2Literal(comparisonPredicate, child,
                    (DateTimeV2Literal) newRight);
        }
        if ((child.getDataType() instanceof DateType || child.getDataType() instanceof DateV2Type)
                && newRight instanceof DateTimeLiteral) {
            DateTimeLiteral literal = (DateTimeLiteral) newRight;
            if (cannotAdjust(literal, comparisonPredicate)) {
                return comparisonPredicate;
            }
            AdjustType adjustType = AdjustType.NONE;
            if (comparisonPredicate instanceof GreaterThanEqual || comparisonPredicate instanceof LessThan) {
                adjustType = AdjustType.UPPER;
            } else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) {
                adjustType = AdjustType.LOWER;
            }
            newRight = migrateToDateV2(literal, adjustType);
            if (child.getDataType() instanceof DateV2Type) {
                newLeft = child;
            }
        }
        if (child.getDataType() instanceof DateType && newRight instanceof DateV2Literal) {
            newLeft = child;
            newRight = migrateToDate((DateV2Literal) newRight);
        }
        boolean unchanged = newLeft == comparisonPredicate.left() && newRight == comparisonPredicate.right();
        return unchanged ? comparisonPredicate : comparisonPredicate.withChildren(newLeft, newRight);
    }

    /**
     * Simplifies {@code cast(integerLike) op floatLiteral} by converting the FLOAT/DOUBLE literal to
     * an exact decimal and delegating to {@link #processIntegerDecimalLiteralComparison}.
     */
    private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate, Expression left,
            Expression right) {
        if (!(left instanceof Cast)
                || !left.child(0).getDataType().isIntegerLikeType()
                || !(right instanceof DoubleLiteral || right instanceof FloatLiteral)) {
            return comparisonPredicate;
        }
        BigDecimal literalValue;
        try {
            literalValue = new BigDecimal(((Literal) right).getStringValue());
        } catch (NumberFormatException ignored) {
            // Non-finite values (e.g. NaN / Infinity, if they reach here) have no decimal form; keep as is.
            return comparisonPredicate;
        }
        return processIntegerDecimalLiteralComparison(comparisonPredicate, ((Cast) left).child(), literalValue);
    }

    /**
     * Simplifies {@code cast(child) op decimalV3Literal}.
     *
     * <p>For an integer-like child the comparison is delegated to
     * {@link #processIntegerDecimalLiteralComparison}. For a DECIMALV3 child whose scale is smaller
     * than the literal's, the literal is brought to the child's scale: exactly for {@code =} and
     * {@code <=>} (folding to a constant when that is impossible), rounded down for {@code >} and
     * {@code <=}, rounded up for {@code <} and {@code >=}.
     */
    private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate, Expression left,
            Expression right) {
        if (!(left instanceof Cast) || !(right instanceof DecimalV3Literal)) {
            return comparisonPredicate;
        }
        Expression child = ((Cast) left).child();
        DecimalV3Literal literal = (DecimalV3Literal) right;
        if (child.getDataType().isIntegerLikeType()) {
            return processIntegerDecimalLiteralComparison(comparisonPredicate, child, literal.getValue());
        }
        if (!child.getDataType().isDecimalV3Type()) {
            return comparisonPredicate;
        }
        DecimalV3Type childType = (DecimalV3Type) child.getDataType();
        int scale = childType.getScale();
        if (scale >= ((DecimalV3Type) literal.getDataType()).getScale()) {
            return comparisonPredicate;
        }
        if (comparisonPredicate instanceof EqualTo || comparisonPredicate instanceof NullSafeEqual) {
            try {
                // setScale without a rounding mode throws if any non-zero digit would be discarded.
                return comparisonPredicate.withChildren(child,
                        new DecimalV3Literal(childType, literal.getValue().setScale(scale)));
            } catch (ArithmeticException e) {
                // The literal has digits beyond the child's scale, so no child value can equal it.
                return comparisonPredicate instanceof EqualTo ? falseOrNull(child) : BooleanLiteral.of(false);
            }
        }
        if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) {
            return comparisonPredicate.withChildren(child, literal.roundFloor(scale));
        }
        if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) {
            return comparisonPredicate.withChildren(child, literal.roundCeiling(scale));
        }
        return comparisonPredicate;
    }

    /**
     * Rewrites {@code cast(integerLike) op value} into {@code integerLike op integerLiteral}.
     *
     * <p>An integral {@code value} is used directly. A value with a fractional part folds {@code =}
     * and {@code <=>} to a constant, rounds down for {@code >}/{@code <=} and rounds up for
     * {@code <}/{@code >=}. Values outside the BIGINT range leave the predicate unchanged.
     *
     * @param expression the integer-like operand that replaces the cast
     * @param bigDecimal the exact value of the literal operand
     */
    private Expression processIntegerDecimalLiteralComparison(ComparisonPredicate comparisonPredicate,
            Expression expression, BigDecimal bigDecimal) {
        if (bigDecimal.compareTo(BigDecimal.valueOf(Long.MIN_VALUE)) < 0
                || bigDecimal.compareTo(BigDecimal.valueOf(Long.MAX_VALUE)) > 0) {
            return comparisonPredicate;
        }
        // Normalize the representation: 1.0 (scale 1) is the integer 1, and 1E+10 (scale < 0) is integral too.
        BigDecimal value = bigDecimal.stripTrailingZeros();
        if (value.scale() <= 0) {
            // setScale(0) on a non-positive scale is exact and never rounds.
            return comparisonPredicate.withChildren(expression, convertDecimalToIntegerLikeLiteral(value.setScale(0)));
        }
        if (comparisonPredicate instanceof EqualTo) {
            return falseOrNull(expression);
        }
        if (comparisonPredicate instanceof NullSafeEqual) {
            return BooleanLiteral.of(false);
        }
        if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) {
            return comparisonPredicate.withChildren(expression,
                    convertDecimalToIntegerLikeLiteral(value.setScale(0, RoundingMode.FLOOR)));
        }
        if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) {
            return comparisonPredicate.withChildren(expression,
                    convertDecimalToIntegerLikeLiteral(value.setScale(0, RoundingMode.CEILING)));
        }
        return comparisonPredicate;
    }

    /**
     * Converts an integral decimal to the narrowest integer literal (TINYINT, SMALLINT, INT or
     * BIGINT) whose range contains it, taking both the lower and the upper bound into account.
     *
     * @throws IllegalArgumentException if the scale is not 0 or the value is outside the BIGINT range
     */
    private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal bigDecimal) {
        Preconditions.checkArgument(bigDecimal.scale() == 0
                        && bigDecimal.compareTo(BigDecimal.valueOf(Long.MIN_VALUE)) >= 0
                        && bigDecimal.compareTo(BigDecimal.valueOf(Long.MAX_VALUE)) <= 0,
                "decimal literal must have 0 scale and be within [Long.MIN_VALUE, Long.MAX_VALUE]");
        long value = bigDecimal.longValue();
        if (value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE) {
            return new TinyIntLiteral((byte) value);
        }
        if (value >= Short.MIN_VALUE && value <= Short.MAX_VALUE) {
            return new SmallIntLiteral((short) value);
        }
        if (value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE) {
            return new IntegerLiteral((int) value);
        }
        return new BigIntLiteral(value);
    }

    /** Builds a DATETIME literal from the date and whole-second fields of a DATETIMEV2 literal. */
    private Expression migrateToDateTime(DateTimeV2Literal dateTimeV2Literal) {
        return new DateTimeLiteral(dateTimeV2Literal.getYear(), dateTimeV2Literal.getMonth(),
                dateTimeV2Literal.getDay(), dateTimeV2Literal.getHour(), dateTimeV2Literal.getMinute(),
                dateTimeV2Literal.getSecond());
    }

    /**
     * Returns whether the literal has any non-zero time component, including fractional seconds
     * when it is a DATETIMEV2 literal.
     */
    private static boolean hasTimePart(DateTimeLiteral dateTimeLiteral) {
        return dateTimeLiteral.getHour() != 0
                || dateTimeLiteral.getMinute() != 0
                || dateTimeLiteral.getSecond() != 0
                || (dateTimeLiteral instanceof DateTimeV2Literal
                        && ((DateTimeV2Literal) dateTimeLiteral).getMicroSecond() != 0);
    }

    /**
     * Returns whether a DATE-vs-DATETIME equality ({@code =} or {@code <=>}) must be left alone
     * because the literal has a time part and therefore cannot be turned into an equal DATE literal.
     */
    private boolean cannotAdjust(DateTimeLiteral dateTimeLiteral, ComparisonPredicate comparisonPredicate) {
        return (comparisonPredicate instanceof EqualTo || comparisonPredicate instanceof NullSafeEqual)
                && hasTimePart(dateTimeLiteral);
    }

    /**
     * Converts a DATETIME literal to a DATEV2 literal by dropping its time part; with
     * {@link AdjustType#UPPER} the result is moved to the next day when a time part was present.
     */
    private Expression migrateToDateV2(DateTimeLiteral dateTimeLiteral, AdjustType adjustType) {
        DateV2Literal dateV2Literal = new DateV2Literal(dateTimeLiteral.getYear(), dateTimeLiteral.getMonth(),
                dateTimeLiteral.getDay());
        if (adjustType == AdjustType.UPPER && hasTimePart(dateTimeLiteral)) {
            dateV2Literal = (DateV2Literal) dateV2Literal.plusDays(1L);
        }
        return dateV2Literal;
    }

    /** Builds a DATE literal with the same year, month and day as the DATEV2 literal. */
    private Expression migrateToDate(DateV2Literal dateV2Literal) {
        return new DateLiteral(dateV2Literal.getYear(), dateV2Literal.getMonth(), dateV2Literal.getDay());
    }
}
