package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

/**
 * Infers additional comparison predicates by propagating equalities between slots.
 *
 * <p>Every {@link EqualTo} whose two operands are slot references (after stripping casts that
 * {@code validForInfer} accepts) is used as a substitution rule. Each other comparison predicate
 * in the input that mentions one of those slots is rewritten with the slot replaced by the other
 * side of the equality. For example, {@code a = b} together with {@code a > 1} yields
 * {@code b > 1}.
 */
public class PredicatePropagation {

    /**
     * Type family of a comparison. It decides which casts may be looked through during inference.
     * {@link #NONE} marks a predicate that cannot take part in inference.
     */
    private enum InferType {
        NONE(null),
        INTEGRAL(IntegralType.class),
        STRING(CharacterType.class),
        DATE(DateLikeType.class),
        OTHER(DataType.class);

        /** Class an operand's data type must be assignable to; {@code null} only for NONE. */
        private final Class<? extends DataType> superClazz;

        InferType(Class<? extends DataType> superClazz) {
            this.superClazz = superClazz;
        }
    }

    /**
     * A comparison predicate plus the operands that inference works with.
     *
     * <p>{@code left}/{@code right} hold the operands with accepted casts stripped. Either one is
     * empty when that operand is unusable, and in that case {@code inferType} is
     * {@link InferType#NONE}.
     */
    private static final class ComparisonInferInfo {
        public final InferType inferType;
        public final Optional<Expression> left;
        public final Optional<Expression> right;
        public final ComparisonPredicate comparisonPredicate;

        public ComparisonInferInfo(InferType inferType, Optional<Expression> left,
                Optional<Expression> right, ComparisonPredicate comparisonPredicate) {
            this.inferType = inferType;
            this.left = left;
            this.right = right;
            this.comparisonPredicate = comparisonPredicate;
        }
    }

    /**
     * Infers new predicates from {@code predicates}.
     *
     * <p>Each usable slot equality in the set is applied to every other usable comparison
     * predicate in the set. Inferred predicates that are already present in {@code predicates}
     * are dropped.
     *
     * @param predicates the predicates to infer from; not modified
     * @return a new mutable set containing only the newly inferred predicates
     */
    public Set<Expression> infer(Set<Expression> predicates) {
        Set<Expression> inferred = Sets.newHashSet();
        // Classified once, and only after the first usable equality is found, instead of once per
        // equality.
        List<ComparisonInferInfo> candidates = null;
        for (Expression predicate : predicates) {
            if (!(predicate instanceof ComparisonPredicate)) {
                continue;
            }
            ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
            if (equalInfo.inferType == InferType.NONE) {
                continue;
            }
            if (candidates == null) {
                candidates = collectInferCandidates(predicates);
            }
            for (ComparisonInferInfo candidate : candidates) {
                // An equality is never used to rewrite itself.
                if (candidate.comparisonPredicate.equals(predicate)) {
                    continue;
                }
                Expression newPredicate = doInfer(equalInfo, candidate);
                if (newPredicate != null) {
                    inferred.add(newPredicate);
                }
            }
        }
        inferred.removeAll(predicates);
        return inferred;
    }

    /** Classifies every comparison predicate in {@code predicates}, keeping only usable ones. */
    private List<ComparisonInferInfo> collectInferCandidates(Set<Expression> predicates) {
        return predicates.stream()
                .filter(ComparisonPredicate.class::isInstance)
                .map(ComparisonPredicate.class::cast)
                .map(this::inferInferInfo)
                .filter(info -> info.inferType != InferType.NONE)
                .collect(Collectors.toList());
    }

    /**
     * Rewrites {@code predicateInfo}'s comparison by substituting slots according to the
     * equality described by {@code equalInfo}.
     *
     * @return the type-coerced and simplified new predicate, or {@code null} if either operand
     *         cannot be substituted
     */
    private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
        Expression equalLeft = equalInfo.left.get();
        Expression equalRight = equalInfo.right.get();
        Expression newLeft = inferOneSide(predicateInfo.left.get(), equalLeft, equalRight);
        Expression newRight = inferOneSide(predicateInfo.right.get(), equalLeft, equalRight);
        if (newLeft == null || newRight == null) {
            return null;
        }
        ComparisonPredicate newPredicate =
                (ComparisonPredicate) predicateInfo.comparisonPredicate.withChildren(newLeft, newRight);
        return SimplifyComparisonPredicate.INSTANCE.rewrite(
                TypeCoercionUtils.processComparisonPredicate(newPredicate, newLeft, newRight),
                (ExpressionRewriteContext) null);
    }

    /**
     * Returns the replacement for one operand of a predicate being rewritten.
     *
     * <p>A slot equal to one side of the equality maps to the other side. A constant is kept.
     * Integer-like literals are re-created by parsing their SQL text; NOTE(review): presumably
     * this resets a type widened by earlier coercion — confirm. Anything else yields
     * {@code null}.
     */
    private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft,
            Expression equalRight) {
        if (predicateOneSide instanceof SlotReference) {
            if (predicateOneSide.equals(equalLeft)) {
                return equalRight;
            }
            if (predicateOneSide.equals(equalRight)) {
                return equalLeft;
            }
            return null;
        }
        if (predicateOneSide.isConstant()) {
            if (predicateOneSide instanceof IntegerLikeLiteral) {
                return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql());
            }
            return predicateOneSide;
        }
        return null;
    }

    /**
     * Returns the expression that inference should use for {@code expression}, or empty if it is
     * unusable.
     *
     * <p>The expression's data type must belong to {@code inferType}'s family. Slot references
     * and constants are used as-is. A {@link Cast} is looked through recursively when its
     * conversion is allowed for the family:
     * <ul>
     *   <li>INTEGRAL: any cast;</li>
     *   <li>DATE: see {@code canStripDateCast};</li>
     *   <li>STRING: see {@code canStripStringCast}.</li>
     * </ul>
     * All other expressions are rejected.
     */
    private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
        if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
            return Optional.empty();
        }
        if (expression instanceof SlotReference || expression.isConstant()) {
            return Optional.of(expression);
        }
        if (!(expression instanceof Cast)) {
            return Optional.empty();
        }
        Expression child = ((Cast) expression).child();
        boolean canStrip;
        if (inferType == InferType.INTEGRAL) {
            // The recursive call still requires the child to be integral.
            canStrip = true;
        } else if (inferType == InferType.DATE) {
            canStrip = canStripDateCast(expression.getDataType(), child.getDataType());
        } else if (inferType == InferType.STRING) {
            canStrip = canStripStringCast(expression.getDataType(), child.getDataType());
        } else {
            canStrip = false;
        }
        return canStrip ? validForInfer(child, inferType) : Optional.empty();
    }

    /** Whether a cast to date-like {@code toType} from {@code fromType} may be looked through. */
    private static boolean canStripDateCast(DataType toType, DataType fromType) {
        if (toType instanceof DateType) {
            return fromType instanceof DateV2Type || fromType instanceof DateType;
        }
        if (toType instanceof DateV2Type) {
            return fromType instanceof DateType || fromType instanceof DateV2Type;
        }
        if (toType instanceof DateTimeType) {
            return !(fromType instanceof DateTimeV2Type);
        }
        return toType instanceof DateTimeV2Type;
    }

    /**
     * Whether a string cast may be looked through.
     *
     * <p>This holds when the target width is {@code <= 0} (presumably unbounded — confirm), or
     * when it is at least the source width and the source width is non-negative.
     */
    private static boolean canStripStringCast(DataType toType, DataType fromType) {
        return toType.width() <= 0 || (toType.width() >= fromType.width() && fromType.width() >= 0);
    }

    /** Builds inference info for a comparison, classifying it by its left operand's data type. */
    private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
        InferType inferType = classify(comparisonPredicate.left().getDataType());
        Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
        Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
        if (!left.isPresent() || !right.isPresent()) {
            inferType = InferType.NONE;
        }
        return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
    }

    /** Maps a data type to its inference family; everything unrecognised becomes OTHER. */
    private static InferType classify(DataType dataType) {
        if (dataType instanceof CharacterType) {
            return InferType.STRING;
        }
        if (dataType instanceof IntegralType) {
            return InferType.INTEGRAL;
        }
        if (dataType instanceof DateLikeType) {
            return InferType.DATE;
        }
        return InferType.OTHER;
    }

    /**
     * Returns inference info for {@code comparisonPredicate} as a substitution rule.
     *
     * <p>The result is NONE unless the predicate is an {@link EqualTo} whose two usable operands
     * are both slot references.
     */
    private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate comparisonPredicate) {
        if (!(comparisonPredicate instanceof EqualTo)) {
            return new ComparisonInferInfo(InferType.NONE, Optional.of(comparisonPredicate.left()),
                    Optional.of(comparisonPredicate.right()), comparisonPredicate);
        }
        ComparisonInferInfo info = inferInferInfo(comparisonPredicate);
        if (info.inferType == InferType.NONE) {
            return info;
        }
        // Only slot = slot equalities are used as substitution rules.
        if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
            return info;
        }
        return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
    }
}
