package org.apache.doris.nereids.rules.expression.rules;

import java.math.BigDecimal;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

/* loaded from: input_file:org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.class */
/**
 * Expression rewrite rule that simplifies {@link Cast} expressions.
 *
 * <p>For each cast, the child is rewritten first. Then:
 * <ul>
 *   <li>if the rewritten child already has the cast's target type, the cast is dropped;</li>
 *   <li>if the rewritten child is a literal of a supported kind, the cast is folded into a
 *       literal of the target type (VARCHAR/CHAR to STRING or VARCHAR, integer literals to
 *       DECIMALV2/DECIMALV3, DECIMALV3 literals to another DECIMALV3);</li>
 *   <li>otherwise the cast is kept, rebuilt around the rewritten child if the child changed.</li>
 * </ul>
 */
public class SimplifyCastRule extends AbstractExpressionRewriteRule {
    /** Shared instance of this rule; the rule holds no per-instance state. */
    public static final SimplifyCastRule INSTANCE = new SimplifyCastRule();

    @Override
    public Expression visitCast(Cast cast, ExpressionRewriteContext expressionRewriteContext) {
        return simplify(cast, expressionRewriteContext);
    }

    /**
     * Simplifies one cast after recursively rewriting its child.
     *
     * @param cast the cast expression to simplify
     * @param expressionRewriteContext context forwarded to the child rewrite
     * @return the rewritten child, a folded literal, a new cast over the rewritten child,
     *         or {@code cast} itself when nothing changed
     */
    private Expression simplify(Cast cast, ExpressionRewriteContext expressionRewriteContext) {
        Expression child = rewrite(cast.child(), expressionRewriteContext);
        DataType targetType = cast.getDataType();

        // CAST(x AS T) where x already has type T: the cast is redundant.
        if (targetType.equals(child.getDataType())) {
            return child;
        }

        if (child instanceof Literal) {
            Expression folded;
            try {
                folded = foldLiteral((Literal) child, targetType);
            } catch (Exception ignored) {
                // Folding is best-effort: if the literal cannot be represented in the target type
                // (e.g. a literal constructor rejects the value), keep the cast instead of failing.
                // Unlike before, we fall through below so a rewritten child is not discarded, and
                // we no longer catch Throwable so JVM Errors (OOM, stack overflow) still propagate.
                folded = null;
            }
            if (folded != null) {
                return folded;
            }
        }

        // Keep the cast; rebuild it only if the child rewrite produced a new expression.
        return child != cast.child() ? new Cast(child, targetType) : cast;
    }

    /**
     * Attempts to fold {@code CAST(literal AS targetType)} into a single literal.
     *
     * @param literal the (already rewritten) literal being cast
     * @param targetType the cast's target type
     * @return the folded literal, or {@code null} if this combination is not handled
     */
    private static Expression foldLiteral(Literal literal, DataType targetType) {
        if (targetType instanceof StringType) {
            if (literal instanceof VarcharLiteral) {
                return new StringLiteral(((VarcharLiteral) literal).getValue());
            }
            if (literal instanceof CharLiteral) {
                return new StringLiteral(((CharLiteral) literal).getValue());
            }
        } else if (targetType instanceof VarcharType) {
            int len = ((VarcharType) targetType).getLen();
            if (literal instanceof VarcharLiteral) {
                return new VarcharLiteral(((VarcharLiteral) literal).getValue(), len);
            }
            if (literal instanceof CharLiteral) {
                return new VarcharLiteral(((CharLiteral) literal).getValue(), len);
            }
        } else if (targetType instanceof DecimalV2Type) {
            BigDecimal value = integralValue(literal);
            if (value != null) {
                // NOTE(review): this DecimalLiteral constructor receives only the value, not the
                // target DecimalV2Type's precision/scale; confirm the resulting literal's type is
                // acceptable where CAST(... AS DECIMAL(p, s)) is expected.
                return new DecimalLiteral(value);
            }
        } else if (targetType instanceof DecimalV3Type) {
            DecimalV3Type decimalV3Type = (DecimalV3Type) targetType;
            BigDecimal value = integralValue(literal);
            if (value != null) {
                return new DecimalV3Literal(decimalV3Type, value);
            }
            if (literal instanceof DecimalV3Literal) {
                return new DecimalV3Literal(decimalV3Type, ((DecimalV3Literal) literal).getValue());
            }
        }
        return null;
    }

    /**
     * Returns the exact value of a TINYINT, SMALLINT, INT or BIGINT literal.
     *
     * @param literal any literal
     * @return the literal's value with scale 0, or {@code null} if it is not one of those kinds
     */
    private static BigDecimal integralValue(Literal literal) {
        if (literal instanceof TinyIntLiteral) {
            return BigDecimal.valueOf(((TinyIntLiteral) literal).getValue().longValue());
        }
        if (literal instanceof SmallIntLiteral) {
            return BigDecimal.valueOf(((SmallIntLiteral) literal).getValue().longValue());
        }
        if (literal instanceof IntegerLiteral) {
            return BigDecimal.valueOf(((IntegerLiteral) literal).getValue().longValue());
        }
        if (literal instanceof BigIntLiteral) {
            return BigDecimal.valueOf(((BigIntLiteral) literal).getValue().longValue());
        }
        return null;
    }
}
