package org.apache.doris.nereids.rules.expression.rules;

import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

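/**
 * Expression rewrite rule that simplifies AND/OR trees by merging comparisons and IN-lists against numeric
 * literals on the same expression, e.g. {@code a > 1 AND a > 3} becomes {@code a > 3} and
 * {@code a = 1 OR a IN (1, 2)} becomes {@code a = 1 OR a = 2}.
 */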
public class SimplifyRange extends AbstractExpressionRewriteRule {
    /** Shared instance; the rule has no instance fields, and each rewrite call creates its own RangeInference visitor. */
    public static final SimplifyRange INSTANCE = new SimplifyRange();

    /**
     * A finite set of literal values that {@code reference} may take, built from {@code reference = literal}
     * or {@code reference IN (literal, ...)} predicates.
     *
     * <p>The values are kept in a sorted set ({@code Sets.newTreeSet}), so duplicates collapse and the rebuilt
     * predicate lists the literals in their natural order.
     */
    public static class DiscreteValue extends ValueDesc {
        Set<Literal> values;

        public DiscreteValue(Expression reference, Expression expr, Literal... literals) {
            this(reference, expr, Arrays.asList(literals));
        }

        public DiscreteValue(Expression reference, Expression expr, Collection<Literal> literals) {
            super(reference, expr);
            this.values = Sets.newTreeSet(literals);
        }

        /**
         * OR-merges this value set with {@code other}.
         *
         * <p>Two discrete sets merge into their set union. A range operand is delegated to
         * {@code ValueDesc.union(RangeValue, DiscreteValue, boolean)}. Anything that cannot be merged, including
         * an operand that fails the cast or whose literals cannot be compared, becomes an {@link UnknownValue}
         * that ORs both sides as written.
         */
        @Override
        public ValueDesc union(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.union(this);
            }
            try {
                if (other instanceof DiscreteValue) {
                    DiscreteValue merged = new DiscreteValue(reference, ExpressionUtils.or(expr, other.expr));
                    merged.values.addAll(((DiscreteValue) other).values);
                    merged.values.addAll(values);
                    return merged;
                }
                // A non-range operand (e.g. UnknownValue) throws ClassCastException here and takes the fallback.
                return ValueDesc.union((RangeValue) other, this, true);
            } catch (Exception ignored) {
                return new UnknownValue(ImmutableList.of(this, other), ExpressionUtils.or(expr, other.expr),
                        ExpressionUtils::or);
            }
        }

        /**
         * AND-merges this value set with {@code other}.
         *
         * <p>Two discrete sets keep only their common values, or become an {@link EmptyValue} when nothing is
         * shared. A range operand is delegated to {@code ValueDesc.intersect(RangeValue, DiscreteValue)}.
         * Anything else becomes an {@link UnknownValue} that ANDs both sides as written.
         */
        @Override
        public ValueDesc intersect(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.intersect(this);
            }
            try {
                if (other instanceof DiscreteValue) {
                    Expression conjunction = ExpressionUtils.and(expr, other.expr);
                    DiscreteValue common = new DiscreteValue(reference, conjunction);
                    common.values.addAll(((DiscreteValue) other).values);
                    common.values.retainAll(values);
                    if (common.values.isEmpty()) {
                        return new EmptyValue(reference, conjunction);
                    }
                    return common;
                }
                // A non-range operand throws ClassCastException here and takes the fallback.
                return ValueDesc.intersect((RangeValue) other, this);
            } catch (Exception ignored) {
                return new UnknownValue(ImmutableList.of(this, other), ExpressionUtils.and(expr, other.expr),
                        ExpressionUtils::and);
            }
        }

        /**
         * Renders the set as {@code ref = v} for one value, {@code ref = v1 OR ref = v2} for two values, and
         * {@code ref IN (...)} otherwise.
         */
        @Override
        public Expression toExpression() {
            switch (values.size()) {
                case 1:
                    return new EqualTo(reference, values.iterator().next());
                case 2: {
                    Iterator<Literal> cursor = values.iterator();
                    Literal first = cursor.next();
                    Literal second = cursor.next();
                    return new Or(new EqualTo(reference, first), new EqualTo(reference, second));
                }
                default:
                    return new InPredicate(reference, Lists.newArrayList(values));
            }
        }

        @Override
        public String toString() {
            return values.toString();
        }
    }

    /**
     * The value of a predicate that no value of {@code reference} can satisfy, e.g. {@code a > 5 AND a < 1}.
     * It is the identity element for {@link #union}, absorbs everything under {@link #intersect}, and is
     * rendered as the literal {@code FALSE}.
     *
     * <p>NOTE(review): in SQL such a contradiction evaluates to NULL (not FALSE) when {@code reference} is NULL.
     * Rendering it as FALSE is only equivalent where NULL and FALSE are treated alike (e.g. a filter); confirm
     * this rule is never applied beneath a NOT.
     */
    public static class EmptyValue extends ValueDesc {
        public EmptyValue(Expression reference, Expression expr) {
            super(reference, expr);
        }

        /** {@code FALSE OR x} is just {@code x}. */
        @Override
        public ValueDesc union(ValueDesc other) {
            return other;
        }

        /** {@code FALSE AND x} stays {@code FALSE}. */
        @Override
        public ValueDesc intersect(ValueDesc other) {
            return this;
        }

        @Override
        public Expression toExpression() {
            return BooleanLiteral.FALSE;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/doris/nereids/rules/expression/rules/SimplifyRange$RangeInference.class */
    public static class RangeInference extends ExpressionVisitor<ValueDesc, Void> {
        private RangeInference() {
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visit(Expression expression, Void r7) {
            return new UnknownValue(expression);
        }

        private ValueDesc buildRange(ComparisonPredicate comparisonPredicate) {
            Expression normalize = ExpressionRuleExecutor.normalize(comparisonPredicate);
            Expression child = normalize.child(1);
            return (child.isLiteral() && child.getDataType().isNumericType()) ? ValueDesc.range((ComparisonPredicate) normalize) : new UnknownValue(comparisonPredicate);
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitGreaterThan(GreaterThan greaterThan, Void r5) {
            return buildRange(greaterThan);
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Void r5) {
            return buildRange(greaterThanEqual);
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitLessThan(LessThan lessThan, Void r5) {
            return buildRange(lessThan);
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, Void r5) {
            return buildRange(lessThanEqual);
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitEqualTo(EqualTo equalTo, Void r5) {
            return buildRange(equalTo);
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitInPredicate(InPredicate inPredicate, Void r7) {
            return (ExpressionUtils.isAllLiteral(inPredicate.getOptions()) && ExpressionUtils.matchNumericType(inPredicate.getOptions())) ? ValueDesc.discrete(inPredicate) : new UnknownValue(inPredicate);
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitAnd(And and, Void r8) {
            return simplify(and, ExpressionUtils.extractConjunction(and), (v0, v1) -> {
                return v0.intersect(v1);
            }, (expression, expression2) -> {
                return ExpressionUtils.and(expression, expression2);
            });
        }

        @Override // org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor
        public ValueDesc visitOr(Or or, Void r8) {
            return simplify(or, ExpressionUtils.extractDisjunction(or), (v0, v1) -> {
                return v0.union(v1);
            }, (expression, expression2) -> {
                return ExpressionUtils.or(expression, expression2);
            });
        }

        private ValueDesc simplify(Expression expression, List<Expression> list, BinaryOperator<ValueDesc> binaryOperator, BinaryOperator<Expression> binaryOperator2) {
            Map map = (Map) list.stream().map(expression2 -> {
                return (ValueDesc) expression2.accept(this, null);
            }).collect(Collectors.groupingBy(valueDesc -> {
                return valueDesc.reference;
            }, LinkedHashMap::new, Collectors.toList()));
            ArrayList newArrayList = Lists.newArrayList();
            Iterator it = map.entrySet().iterator();
            while (it.hasNext()) {
                newArrayList.add((ValueDesc) ((List) ((Map.Entry) it.next()).getValue()).stream().reduce(binaryOperator).get());
            }
            return newArrayList.size() == 1 ? (ValueDesc) newArrayList.get(0) : new UnknownValue(newArrayList, expression, binaryOperator2);
        }
    }

    /**
     * A single contiguous interval of values for {@code reference}, built from {@code <, <=, >, >=}
     * comparisons against a literal and merged with {@code Range.span} / {@code Range.intersection}.
     */
    public static class RangeValue extends ValueDesc {
        Range<Literal> range;

        public RangeValue(Expression reference, Expression expr) {
            super(reference, expr);
        }

        /**
         * OR-merges this range with {@code other}. Connected ranges collapse into their span. A discrete set
         * is delegated to {@code ValueDesc.union(RangeValue, DiscreteValue, boolean)}. Disjoint ranges, and
         * any operand that cannot be merged, become an {@link UnknownValue} that ORs both sides as written.
         */
        @Override
        public ValueDesc union(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.union(this);
            }
            try {
                if (!(other instanceof RangeValue)) {
                    // a non-discrete operand (e.g. UnknownValue) fails this cast and falls through to the fallback
                    return ValueDesc.union(this, (DiscreteValue) other, false);
                }
                RangeValue otherRange = (RangeValue) other;
                if (range.isConnected(otherRange.range)) {
                    RangeValue spanned = new RangeValue(reference, ExpressionUtils.or(expr, other.expr));
                    spanned.range = range.span(otherRange.range);
                    return spanned;
                }
                // disjoint intervals (e.g. a < 1 OR a > 5) cannot be expressed as a single range
            } catch (Exception ignored) {
                // incomparable literals or an unsupported operand: keep both sides unchanged
            }
            return new UnknownValue(ImmutableList.of(this, other), ExpressionUtils.or(expr, other.expr),
                    ExpressionUtils::or);
        }

        /**
         * AND-merges this range with {@code other}. Overlapping ranges shrink to their intersection. Ranges
         * with no common value, whether they are not connected (a > 5 AND a < 1) or only touch at an excluded
         * endpoint (a >= 2 AND a < 2), become an {@link EmptyValue}. A discrete set is delegated to
         * {@code ValueDesc.intersect(RangeValue, DiscreteValue)}.
         */
        @Override
        public ValueDesc intersect(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.intersect(this);
            }
            try {
                if (!(other instanceof RangeValue)) {
                    // a non-discrete operand fails this cast and falls through to the fallback
                    return ValueDesc.intersect(this, (DiscreteValue) other);
                }
                RangeValue otherRange = (RangeValue) other;
                Expression conjunction = ExpressionUtils.and(expr, other.expr);
                if (!range.isConnected(otherRange.range)) {
                    return new EmptyValue(reference, conjunction);
                }
                Range<Literal> common = range.intersection(otherRange.range);
                if (common.isEmpty()) {
                    // connected but sharing no value, e.g. [2..2) from a >= 2 AND a < 2
                    return new EmptyValue(reference, conjunction);
                }
                RangeValue narrowed = new RangeValue(reference, conjunction);
                narrowed.range = common;
                return narrowed;
            } catch (Exception ignored) {
                // incomparable literals or an unsupported operand: keep both sides unchanged
            }
            return new UnknownValue(ImmutableList.of(this, other), ExpressionUtils.and(expr, other.expr),
                    ExpressionUtils::and);
        }

        /**
         * Renders the range as a conjunction of its lower and upper bound comparisons.
         *
         * <p>An unbounded range (for example the span of {@code a >= 1 OR a <= 5}) is returned as the original
         * expression rather than {@code TRUE}. {@code TRUE} would also accept rows where {@code reference} is
         * NULL, for which the original comparisons evaluate to NULL.
         */
        @Override
        public Expression toExpression() {
            List<Expression> bounds = new ArrayList<>();
            if (range.hasLowerBound()) {
                Literal lower = range.lowerEndpoint();
                if (range.lowerBoundType() == BoundType.CLOSED) {
                    bounds.add(new GreaterThanEqual(reference, lower));
                } else {
                    bounds.add(new GreaterThan(reference, lower));
                }
            }
            if (range.hasUpperBound()) {
                Literal upper = range.upperEndpoint();
                if (range.upperBoundType() == BoundType.CLOSED) {
                    bounds.add(new LessThanEqual(reference, upper));
                } else {
                    bounds.add(new LessThan(reference, upper));
                }
            }
            if (bounds.isEmpty()) {
                return expr;
            }
            return ExpressionUtils.and(bounds);
        }

        @Override
        public String toString() {
            return range == null ? "UnknwonRange" : range.toString();
        }
    }

    /**
     * A value that cannot be reasoned about as a range or value set. It is either a single opaque predicate or
     * a combination of other values (typically constraining different references). A combination is
     * rendered by folding the members' expressions with {@code mergeExprOp}.
     */
    public static class UnknownValue extends ValueDesc {
        private final List<ValueDesc> sourceValues;
        private final BinaryOperator<Expression> mergeExprOp;

        /** Wraps a single opaque predicate, which is used as its own reference. */
        private UnknownValue(Expression expr) {
            super(expr, expr);
            this.sourceValues = ImmutableList.of();
            this.mergeExprOp = null;
        }

        /**
         * Combines {@code sourceValues} (non-empty; the first one supplies the reference).
         *
         * @param sourceValues the values to recombine
         * @param originExpr the expression the combination was derived from
         * @param mergeExprOp and/or, used to fold the members' expressions in {@link #toExpression()}
         */
        public UnknownValue(List<ValueDesc> sourceValues, Expression originExpr,
                BinaryOperator<Expression> mergeExprOp) {
            super(sourceValues.get(0).reference, originExpr);
            this.sourceValues = ImmutableList.copyOf(sourceValues);
            this.mergeExprOp = mergeExprOp;
        }

        @Override
        public ValueDesc union(ValueDesc other) {
            if (other instanceof EmptyValue) {
                // x OR FALSE == x; matches EmptyValue.union so the result does not depend on operand order
                return this;
            }
            return new UnknownValue(ImmutableList.of(this, other), ExpressionUtils.or(expr, other.expr),
                    ExpressionUtils::or);
        }

        @Override
        public ValueDesc intersect(ValueDesc other) {
            if (other instanceof EmptyValue) {
                // x AND FALSE == FALSE; matches EmptyValue.intersect
                return other;
            }
            return new UnknownValue(ImmutableList.of(this, other), ExpressionUtils.and(expr, other.expr),
                    ExpressionUtils::and);
        }

        /** Returns the wrapped predicate, or the members' rendered expressions folded with {@code mergeExprOp}. */
        @Override
        public Expression toExpression() {
            if (sourceValues.isEmpty()) {
                return expr;
            }
            return sourceValues.stream()
                    .map(ValueDesc::toExpression)
                    .reduce(mergeExprOp)
                    .get();
        }
    }

    /**
     * What a boolean predicate says about the values of one {@code reference} expression.
     *
     * <p>{@code expr} is the source predicate the descriptor was derived from. {@code SimplifyRange.rewrite}
     * falls back to it when {@link #toExpression()} returns {@code null}.
     */
    public abstract static class ValueDesc {
        Expression expr;
        Expression reference;

        public ValueDesc(Expression reference, Expression expr) {
            this.expr = expr;
            this.reference = reference;
        }

        /** Returns the descriptor for {@code this OR other}; both constrain the same reference. */
        public abstract ValueDesc union(ValueDesc other);

        /**
         * OR-merges a range with a discrete value set.
         *
         * @param rangeValue the range side
         * @param discreteValue the value-set side
         * @param reverseOrder {@code true} when the discrete side is the left operand. It only decides the
         *     order of the two sides kept in the resulting {@link UnknownValue}.
         * @return {@code rangeValue} itself when it already contains every discrete value, otherwise an
         *     {@link UnknownValue} that ORs both sides as written
         */
        public static ValueDesc union(RangeValue rangeValue, DiscreteValue discreteValue, boolean reverseOrder) {
            if (discreteValue.values.stream().allMatch(rangeValue.range::contains)) {
                return rangeValue;
            }
            List<ValueDesc> sourceValues = reverseOrder
                    ? ImmutableList.of(discreteValue, rangeValue)
                    : ImmutableList.of(rangeValue, discreteValue);
            return new UnknownValue(sourceValues, ExpressionUtils.or(rangeValue.expr, discreteValue.expr),
                    ExpressionUtils::or);
        }

        /** Returns the descriptor for {@code this AND other}; both constrain the same reference. */
        public abstract ValueDesc intersect(ValueDesc other);

        /**
         * AND-merges a range with a discrete value set by keeping only the values the range contains.
         *
         * @return the surviving values, or an {@link EmptyValue} when none of them lies in the range
         */
        public static ValueDesc intersect(RangeValue rangeValue, DiscreteValue discreteValue) {
            Expression conjunction = ExpressionUtils.and(rangeValue.expr, discreteValue.expr);
            DiscreteValue inRange = new DiscreteValue(discreteValue.reference, conjunction);
            discreteValue.values.stream()
                    .filter(rangeValue.range::contains)
                    .forEach(inRange.values::add);
            if (inRange.values.isEmpty()) {
                return new EmptyValue(rangeValue.reference, conjunction);
            }
            return inRange;
        }

        /** Renders this descriptor back to a predicate. */
        public abstract Expression toExpression();

        /**
         * Builds a descriptor from a normalized comparison whose right-hand side is a literal: {@code =} becomes
         * a one-value {@link DiscreteValue}, and {@code >, >=, <, <=} become a half-bounded {@link RangeValue}.
         * Any other comparison is kept opaque instead of producing a range-less RangeValue.
         */
        public static ValueDesc range(ComparisonPredicate predicate) {
            Literal value = (Literal) predicate.right();
            if (predicate instanceof EqualTo) {
                return new DiscreteValue(predicate.left(), predicate, value);
            }
            RangeValue rangeValue = new RangeValue(predicate.left(), predicate);
            if (predicate instanceof GreaterThanEqual) {
                rangeValue.range = Range.atLeast(value);
            } else if (predicate instanceof GreaterThan) {
                rangeValue.range = Range.greaterThan(value);
            } else if (predicate instanceof LessThanEqual) {
                rangeValue.range = Range.atMost(value);
            } else if (predicate instanceof LessThan) {
                rangeValue.range = Range.lessThan(value);
            } else {
                // guard: a null range would fail later in union/intersect/toExpression
                return new UnknownValue(predicate);
            }
            return rangeValue;
        }

        /** Builds a {@link DiscreteValue} from an IN predicate whose options are all literals. */
        public static ValueDesc discrete(InPredicate inPredicate) {
            Set<Literal> literals = inPredicate.getOptions().stream()
                    .map(Literal.class::cast)
                    .collect(Collectors.toSet());
            return new DiscreteValue(inPredicate.getCompareExpr(), inPredicate, literals);
        }
    }

    /**
     * Simplifies {@code expression} when it is an AND/OR tree; any other expression is returned unchanged.
     *
     * @param expression the expression to rewrite
     * @param expressionRewriteContext not used by this rule
     * @return the simplified predicate, or the inferred descriptor's source expression when no simplified
     *     form is produced
     */
    @Override
    public Expression rewrite(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
        if (expression instanceof CompoundPredicate) {
            ValueDesc inferred = expression.accept(new RangeInference(), null);
            Expression simplified = inferred.toExpression();
            if (simplified != null) {
                return simplified;
            }
            return inferred.expr;
        }
        return expression;
    }
}
