package org.apache.doris.nereids.rules.expression.rules;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.util.TypeUtils;

/**
 * Expression rewrite rule that regroups a chain of {@code +}/{@code -} (or {@code *}/{@code /})
 * operations so that all constant operands end up in a single sub-expression.
 *
 * <p>For example {@code a + 1 - b + 2} is rebuilt as {@code ((a - b) + (1 + 2))}. This rule only
 * regroups operands; it constructs new arithmetic nodes for the constant group but does not evaluate
 * them itself. Chains that contain any decimal-like operand are returned unchanged.
 */
public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
    /** Shared instance; the rule keeps no per-instance state. */
    public static final SimplifyArithmeticRule INSTANCE = new SimplifyArithmeticRule();

    /**
     * One leaf of a flattened arithmetic chain together with its sign.
     *
     * <p>{@code flag == true} means the operand is added (in a {@code +}/{@code -} chain) or is a
     * factor in the numerator (in a {@code *} / {@code /} chain). {@code flag == false} means it is
     * subtracted, or is a factor in the denominator.
     */
    public static class Operand {
        boolean flag;
        Expression expression;

        public Operand(boolean flag, Expression expression) {
            this.flag = flag;
            this.expression = expression;
        }

        /** Static factory equivalent to {@link #Operand(boolean, Expression)}. */
        public static Operand of(boolean flag, Expression expression) {
            return new Operand(flag, expression);
        }

        @Override
        public String toString() {
            return this.flag + " : " + this.expression;
        }
    }

    @Override
    public Expression visitAdd(Add add, ExpressionRewriteContext context) {
        return process(add, true);
    }

    @Override
    public Expression visitSubtract(Subtract subtract, ExpressionRewriteContext context) {
        return process(subtract, true);
    }

    @Override
    public Expression visitDivide(Divide divide, ExpressionRewriteContext context) {
        return process(divide, false);
    }

    @Override
    public Expression visitMultiply(Multiply multiply, ExpressionRewriteContext context) {
        return process(multiply, false);
    }

    /**
     * Flattens {@code arithmetic}, merges its constant operands into one combined operand and
     * rebuilds the chain as a left fold.
     *
     * @param arithmetic root of the chain being rewritten
     * @param isAddOrSub {@code true} for a {@code +}/{@code -} chain, {@code false} for a
     *                   {@code *} / {@code /} chain
     * @return the rebuilt expression, or {@code arithmetic} itself if any operand is decimal-like
     */
    private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) {
        List<Operand> flattened = flatten(arithmetic, isAddOrSub);

        // NOTE(review): decimal chains are left alone; presumably regrouping would change how decimal
        // precision/scale is derived for intermediate results -- confirm.
        if (flattened.stream().anyMatch(operand -> operand.expression.getDataType().isDecimalLikeType())) {
            return arithmetic;
        }

        // Split operands into variables and constants, preserving their original order.
        List<Operand> variables = Lists.newArrayList();
        List<Operand> constants = Lists.newArrayList();
        for (Operand operand : flattened) {
            if (operand.expression.isConstant()) {
                constants.add(operand);
            } else {
                variables.add(operand);
            }
        }

        if (!constants.isEmpty()) {
            // Constants are folded relative to the sign of the first constant. When that constant is
            // negative (subtracted / in the denominator), the fold yields the inverse of their true
            // combined value, which is then attached with a negative sign.
            boolean isOpposite = !constants.get(0).flag;
            Optional<Operand> combined = constants.stream().reduce((acc, next) -> {
                boolean sameSignAsFirst = next.flag != isOpposite;
                Expression expr = sameSignAsFirst
                        ? getAddOrMultiply(isAddOrSub, acc, next)
                        : getSubOrDivide(isAddOrSub, acc, next);
                return Operand.of(true, expr);
            });
            // 'constants' is non-empty, so the reduction always produces a value.
            Operand folded = Operand.of(!isOpposite, combined.get().expression);

            // The leftmost flattened operand always has flag == true. So when isOpposite holds, the
            // first variable is positive, and appending is always safe.
            boolean firstVariablePositive = variables.isEmpty() || variables.get(0).flag;
            if (isOpposite || firstVariablePositive) {
                variables.add(folded);
            } else {
                // The chain starts with a positive constant followed by a negative variable
                // (e.g. 1 - x). Put the constant group first, because the final left fold never
                // looks at the first operand's own sign.
                variables.add(0, folded);
            }
        }

        // Left fold: each subsequent operand is combined using its sign; the first operand is
        // taken as-is.
        Optional<Operand> result = variables.stream().reduce((acc, next) -> Operand.of(true, next.flag
                ? getAddOrMultiply(isAddOrSub, acc, next)
                : getSubOrDivide(isAddOrSub, acc, next)));
        return result.map(operand -> operand.expression).orElse(arithmetic);
    }

    /**
     * Flattens a {@code +}/{@code -} or {@code *} / {@code /} tree into signed leaf operands in
     * left-to-right order. The root is treated as positive.
     */
    private List<Operand> flatten(Expression expr, boolean isAddOrSub) {
        List<Operand> result = new ArrayList<>();
        if (isAddOrSub) {
            flattenAddSubtract(true, expr, result);
        } else {
            flattenMultiplyDivide(true, expr, result);
        }
        return result;
    }

    /**
     * Recursively appends the leaves of an add/subtract tree to {@code result}.
     *
     * @param flag sign of {@code expr} within the whole chain ({@code true} = added)
     */
    private void flattenAddSubtract(boolean flag, Expression expr, List<Operand> result) {
        if (!TypeUtils.isAddOrSubtract(expr)) {
            result.add(Operand.of(flag, expr));
            return;
        }
        BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
        // The left child inherits the sign of the whole sub-expression.
        flattenAddSubtract(flag, arithmetic.left(), result);

        boolean isSubtract = TypeUtils.isSubtract(expr);
        boolean rightFlag;
        if (!flag && isSubtract) {
            rightFlag = true;           // -(a - b) = -a + b
        } else if (!flag && TypeUtils.isAdd(expr)) {
            rightFlag = false;          // -(a + b) = -a - b
        } else {
            rightFlag = !isSubtract;    // +(a + b) keeps b positive, +(a - b) negates b
        }
        flattenAddSubtract(rightFlag, arithmetic.right(), result);
    }

    /**
     * Recursively appends the leaves of a multiply/divide tree to {@code result}.
     *
     * @param flag position of {@code expr} within the whole chain ({@code true} = numerator)
     */
    private void flattenMultiplyDivide(boolean flag, Expression expr, List<Operand> result) {
        if (!TypeUtils.isMultiplyOrDivide(expr)) {
            result.add(Operand.of(flag, expr));
            return;
        }
        BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
        // The left child inherits the position of the whole sub-expression.
        flattenMultiplyDivide(flag, arithmetic.left(), result);

        boolean isDivide = TypeUtils.isDivide(expr);
        boolean rightFlag;
        if (!flag && isDivide) {
            rightFlag = true;           // 1 / (a / b) = b / a: b moves to the numerator
        } else if (!flag && TypeUtils.isMultiply(expr)) {
            rightFlag = false;          // 1 / (a * b): b stays in the denominator
        } else {
            rightFlag = !isDivide;      // numerator: a * b keeps b on top, a / b puts b below
        }
        flattenMultiplyDivide(rightFlag, arithmetic.right(), result);
    }

    /** Builds {@code left - right} for add/sub chains, otherwise {@code left / right}. */
    private Expression getSubOrDivide(boolean isAddOrSub, Operand left, Operand right) {
        return isAddOrSub
                ? new Subtract(left.expression, right.expression)
                : new Divide(left.expression, right.expression);
    }

    /** Builds {@code left + right} for add/sub chains, otherwise {@code left * right}. */
    private Expression getAddOrMultiply(boolean isAddOrSub, Operand left, Operand right) {
        return isAddOrSub
                ? new Add(left.expression, right.expression)
                : new Multiply(left.expression, right.expression);
    }
}
