package org.apache.doris.nereids.util;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;

/**
 * Static helpers for building, decomposing, rewriting and inspecting Nereids {@link Expression} trees.
 */
public class ExpressionUtils {
    /** Immutable empty conjunct list, usable as a shared "no condition" value. */
    public static final List<Expression> EMPTY_CONDITION = ImmutableList.of();

    /**
     * Rewriter that substitutes sub-expressions according to a replacement map passed as the visitor context.
     *
     * <p>A node that is a key of the map (by {@code equals}/{@code hashCode}) is replaced by the mapped value,
     * which is returned as-is without visiting its children. Any other node is handed to
     * {@link DefaultExpressionRewriter#visit} with the same map as context.
     */
    public static class ExpressionReplacer
            extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Expression>> {
        /** Shared instance; this class declares no fields of its own. */
        public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();

        private ExpressionReplacer() {
        }

        @Override
        public Expression visit(Expression expression, Map<? extends Expression, ? extends Expression> map) {
            // containsKey (rather than a null check on get) keeps explicit null mappings meaningful.
            if (map.containsKey(expression)) {
                return map.get(expression);
            }
            return super.visit(expression, map);
        }
    }

    /**
     * Flattens nested {@link And} nodes into their operands, in left-to-right order.
     *
     * @param expression the expression to split; a non-{@code And} input yields a one-element list
     * @return a new mutable list of the conjuncts
     */
    public static List<Expression> extractConjunction(Expression expression) {
        return extract(And.class, expression);
    }

    /**
     * Like {@link #extractConjunction(Expression)} but collects the conjuncts into a new mutable
     * {@link HashSet}, so duplicates are dropped and order is not preserved.
     */
    public static Set<Expression> extractConjunctionToSet(Expression expression) {
        Set<Expression> conjuncts = new HashSet<>();
        extract(And.class, expression, conjuncts);
        return conjuncts;
    }

    /**
     * Flattens nested {@link Or} nodes into their operands, in left-to-right order.
     *
     * @param expression the expression to split; a non-{@code Or} input yields a one-element list
     * @return a new mutable list of the disjuncts
     */
    public static List<Expression> extractDisjunction(Expression expression) {
        return extract(Or.class, expression);
    }

    /**
     * Flattens {@code compoundPredicate} by its own runtime class: nested nodes of exactly that class are
     * expanded, anything else becomes a leaf.
     */
    public static List<Expression> extract(CompoundPredicate compoundPredicate) {
        return extract(compoundPredicate.getClass(), compoundPredicate);
    }

    /** Collects the leaves of the {@code cls}-tree rooted at {@code expression} into a new list. */
    private static List<Expression> extract(Class<? extends Expression> cls, Expression expression) {
        List<Expression> leaves = new ArrayList<>();
        extract(cls, expression, leaves);
        return leaves;
    }

    /**
     * Appends to {@code collection} every maximal sub-expression that is not an instance of {@code cls},
     * visiting left before right. Instances of {@code cls} are expected to be {@link CompoundPredicate}s.
     */
    private static void extract(Class<? extends Expression> cls, Expression expression,
            Collection<Expression> collection) {
        if (!cls.isInstance(expression)) {
            collection.add(expression);
            return;
        }
        CompoundPredicate compoundPredicate = (CompoundPredicate) expression;
        extract(cls, compoundPredicate.left(), collection);
        extract(cls, compoundPredicate.right(), collection);
    }

    /** Returns {@link Optional#empty()} for an empty list, otherwise the AND of all elements. */
    public static Optional<Expression> optionalAnd(List<Expression> list) {
        return list.isEmpty() ? Optional.empty() : Optional.of(and(list));
    }

    /**
     * ANDs two conjunct lists together.
     *
     * @return empty if both lists are empty; the AND of the non-empty list if only one is empty; otherwise
     *     {@code new And(and(list), and(list2))}
     */
    public static Optional<Expression> optionalAnd(List<Expression> list, List<Expression> list2) {
        if (list.isEmpty()) {
            return optionalAnd(list2);
        }
        if (list2.isEmpty()) {
            return optionalAnd(list);
        }
        return Optional.of(new And(and(list), and(list2)));
    }

    /** Varargs form of {@link #optionalAnd(List)}. */
    public static Optional<Expression> optionalAnd(Expression... expressionArr) {
        return optionalAnd(Arrays.asList(expressionArr));
    }

    /**
     * Collection form of {@link #optionalAnd(List)}. The input is copied with {@link ImmutableList#copyOf},
     * which rejects null elements.
     */
    public static Optional<Expression> optionalAnd(Collection<Expression> collection) {
        return optionalAnd(ImmutableList.copyOf(collection));
    }

    /** ANDs all expressions together; see {@link #combine(Class, Collection)} for the simplifications applied. */
    public static Expression and(Collection<Expression> collection) {
        return combine(And.class, collection);
    }

    /** Varargs form of {@link #and(Collection)}. */
    public static Expression and(Expression... expressionArr) {
        return combine(And.class, Arrays.asList(expressionArr));
    }

    /** Returns {@link Optional#empty()} for an empty list, otherwise the OR of all elements. */
    public static Optional<Expression> optionalOr(List<Expression> list) {
        return list.isEmpty() ? Optional.empty() : Optional.of(or(list));
    }

    /** Varargs form of {@link #or(Collection)}. */
    public static Expression or(Expression... expressionArr) {
        return combine(Or.class, Arrays.asList(expressionArr));
    }

    /** ORs all expressions together; see {@link #combine(Class, Collection)} for the simplifications applied. */
    public static Expression or(Collection<Expression> collection) {
        return combine(Or.class, collection);
    }

    /**
     * Combines expressions with {@link And} or {@link Or}, simplifying boolean literals along the way.
     *
     * <ul>
     *   <li>If any element equals the absorbing literal (FALSE for AND, TRUE for OR), that literal is returned.</li>
     *   <li>Elements equal to the identity literal (TRUE for AND, FALSE for OR) are dropped.</li>
     *   <li>Duplicates are removed, keeping first-occurrence order.</li>
     *   <li>The survivors are folded left to right, e.g. {@code ((a AND b) AND c)}.</li>
     *   <li>If nothing survives, {@code BooleanLiteral.of(cls == And.class)} is returned.</li>
     * </ul>
     *
     * @param cls must be {@code And.class} or {@code Or.class}
     * @param collection the operands; must not be null
     * @throws IllegalArgumentException if {@code cls} is neither {@code And} nor {@code Or}
     * @throws NullPointerException if {@code collection} is null
     */
    public static Expression combine(Class<? extends Expression> cls, Collection<Expression> collection) {
        Preconditions.checkArgument(cls == And.class || cls == Or.class);
        Objects.requireNonNull(collection, "expressions is null");
        boolean isAnd = cls == And.class;
        BooleanLiteral absorbing = isAnd ? BooleanLiteral.FALSE : BooleanLiteral.TRUE;
        BooleanLiteral identity = isAnd ? BooleanLiteral.TRUE : BooleanLiteral.FALSE;
        LinkedHashSet<Expression> operands = Sets.newLinkedHashSetWithExpectedSize(collection.size());
        for (Expression expression : collection) {
            if (expression.equals(absorbing)) {
                return absorbing;
            }
            if (!expression.equals(identity)) {
                operands.add(expression);
            }
        }
        Optional<Expression> combined = operands.stream().reduce(isAnd ? And::new : Or::new);
        return combined.orElse(BooleanLiteral.of(isAnd));
    }

    /**
     * Picks the element whose data type reports the smallest {@code width()}.
     *
     * <p>The first element is the initial choice. A later candidate with a negative width is never chosen. A
     * candidate with a non-negative width replaces the current choice when it is strictly narrower, or when the
     * current choice's width is {@code <= 0}.
     *
     * @throws IllegalArgumentException if {@code collection} is empty
     */
    public static <S extends NamedExpression> S selectMinimumColumn(Collection<S> collection) {
        Preconditions.checkArgument(!collection.isEmpty());
        S minimum = null;
        for (S candidate : collection) {
            if (minimum == null) {
                minimum = candidate;
                continue;
            }
            int width = candidate.getDataType().width();
            if (width < 0) {
                continue;
            }
            int minimumWidth = minimum.getDataType().width();
            if (width < minimumWidth || minimumWidth <= 0) {
                minimum = candidate;
            }
        }
        return minimum;
    }

    /** Returns the {@link ExprId} of the slot found by {@link #extractSlotOrCastOnSlot}, if any. */
    public static Optional<ExprId> isSlotOrCastOnSlot(Expression expression) {
        return extractSlotOrCastOnSlot(expression).map(Slot::getExprId);
    }

    /**
     * Strips any chain of {@link Cast}s and returns the underlying expression if it is a {@link SlotReference}.
     */
    public static Optional<Slot> extractSlotOrCastOnSlot(Expression expression) {
        Expression uncast = getExpressionCoveredByCast(expression);
        return uncast instanceof SlotReference ? Optional.of((Slot) uncast) : Optional.empty();
    }

    /**
     * Replaces every sub-expression of {@code expression} that is a key of {@code map} with its mapped value,
     * using {@link ExpressionReplacer}.
     */
    public static Expression replace(Expression expression, Map<? extends Expression, ? extends Expression> map) {
        return expression.accept(ExpressionReplacer.INSTANCE, map);
    }

    /** Applies {@link #replace(Expression, Map)} to each element and returns an immutable list in input order. */
    public static List<Expression> replace(List<Expression> list, Map<? extends Expression, ? extends Expression> map) {
        return list.stream()
                .map(expression -> replace(expression, map))
                .collect(ImmutableList.toImmutableList());
    }

    /** Applies {@link #replace(Expression, Map)} to each element and returns an immutable set of the results. */
    public static Set<Expression> replace(Set<Expression> set, Map<? extends Expression, ? extends Expression> map) {
        return set.stream()
                .map(expression -> replace(expression, map))
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Calls {@code rewriteDownShortCircuit(function)} on every element and returns an immutable list.
     *
     * <p>The results are cast to {@code E} without a runtime check (the cast is erased). This assumes
     * {@code function} preserves each element's type; verify at call sites that rely on {@code E}.
     */
    @SuppressWarnings("unchecked")
    public static <E extends Expression> List<E> rewriteDownShortCircuit(Collection<E> collection,
            Function<Expression, Expression> function) {
        return collection.stream()
                .map(expression -> (E) expression.rewriteDownShortCircuit(function))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Flattens a mix of {@link Expression} and {@code Expression[]} arguments into one immutable list,
     * expanding arrays in place.
     *
     * @throws ClassCastException if an argument is neither an {@code Expression} nor an {@code Expression[]}
     */
    public static List<Expression> mergeArguments(Object... objArr) {
        ImmutableList.Builder<Expression> builder = ImmutableList.builder();
        for (Object argument : objArr) {
            if (argument instanceof Expression[]) {
                builder.addAll(Arrays.asList((Expression[]) argument));
            } else {
                builder.add((Expression) argument);
            }
        }
        return builder.build();
    }

    /** Returns true if every element is a {@link Literal} (vacuously true for an empty list). */
    public static boolean isAllLiteral(List<Expression> list) {
        return list.stream().allMatch(Literal.class::isInstance);
    }

    /** Returns true if every element's data type is numeric (vacuously true for an empty list). */
    public static boolean matchNumericType(List<Expression> list) {
        return list.stream().allMatch(expression -> expression.getDataType().isNumericType());
    }

    /** Returns true if at least one element is a {@link NullLiteral}. */
    public static boolean hasNullLiteral(List<Expression> list) {
        return list.stream().anyMatch(NullLiteral.class::isInstance);
    }

    /** Returns true if at least one element's data type reports {@code isOnlyMetricType()}. */
    public static boolean hasOnlyMetricType(List<Expression> list) {
        return list.stream().anyMatch(expression -> expression.getDataType().isOnlyMetricType());
    }

    /** Returns true if every element is a {@link NullLiteral} (vacuously true for an empty list). */
    public static boolean isAllNullLiteral(List<Expression> list) {
        return list.stream().allMatch(NullLiteral.class::isInstance);
    }

    /**
     * Finds input slots that must be non-null for some predicate in {@code set} to hold.
     *
     * <p>For each predicate and each of its input slots, the slot is replaced by a typed {@link NullLiteral}
     * and the result is constant-folded with {@link FoldConstantRule}. If the folded value is a null literal or
     * {@code FALSE}, the slot is reported.
     *
     * @return a new mutable set of such slots
     */
    public static Set<Slot> inferNotNullSlots(Set<Expression> set, CascadesContext cascadesContext) {
        Set<Slot> notNullSlots = new HashSet<>();
        for (Expression expression : set) {
            for (Slot slot : expression.getInputSlots()) {
                Map<Expression, Expression> nullForSlot = new HashMap<>();
                nullForSlot.put(slot, new NullLiteral(slot.getDataType()));
                Expression folded = FoldConstantRule.INSTANCE.rewrite(
                        replace(expression, nullForSlot), new ExpressionRewriteContext(cascadesContext));
                // NULL or FALSE after substitution means a row with this slot NULL cannot satisfy the predicate
                // (assuming the predicate is used as a filter).
                if (folded.isNullLiteral() || BooleanLiteral.FALSE.equals(folded)) {
                    notNullSlots.add(slot);
                }
            }
        }
        return notNullSlots;
    }

    /**
     * Builds a generated {@code NOT (slot IS NULL)} predicate for every slot returned by
     * {@link #inferNotNullSlots}.
     */
    public static Set<Expression> inferNotNull(Set<Expression> set, CascadesContext cascadesContext) {
        return inferNotNullSlots(set, cascadesContext).stream()
                .map(ExpressionUtils::generatedIsNotNull)
                .collect(Collectors.toSet());
    }

    /**
     * Like {@link #inferNotNull(Set, CascadesContext)} but only for inferred slots that are contained in
     * {@code set2}.
     *
     * @throws NullPointerException if {@code set2} is null
     */
    public static Set<Expression> inferNotNull(Set<Expression> set, Set<Slot> set2, CascadesContext cascadesContext) {
        return inferNotNullSlots(set, cascadesContext).stream()
                .filter(set2::contains)
                .map(ExpressionUtils::generatedIsNotNull)
                .collect(Collectors.toSet());
    }

    /** Builds {@code NOT (slot IS NULL)} with {@code isGeneratedIsNotNull} set to true. */
    private static Expression generatedIsNotNull(Slot slot) {
        Not isNotNull = new Not(new IsNull(slot));
        isNotNull.isGeneratedIsNotNull = true;
        return isNotNull;
    }

    /** Concatenates the inner lists, in order, into one immutable list. */
    public static <E extends Expression> List<E> flatExpressions(List<List<E>> list) {
        return list.stream()
                .flatMap(List::stream)
                .collect(ImmutableList.toImmutableList());
    }

    /** Returns true if {@code anyMatch(predicate)} is true for at least one expression in {@code list}. */
    public static boolean anyMatch(List<? extends Expression> list, Predicate<TreeNode<Expression>> predicate) {
        return list.stream().anyMatch(expression -> expression.anyMatch(predicate));
    }

    /** Returns true if {@code anyMatch(predicate)} is false for every expression in {@code list}. */
    public static boolean noneMatch(List<? extends Expression> list, Predicate<TreeNode<Expression>> predicate) {
        return list.stream().noneMatch(expression -> expression.anyMatch(predicate));
    }

    /**
     * Returns true if some expression in {@code list} has a node that is an instance of {@code cls}.
     *
     * @throws NullPointerException if {@code cls} is null
     */
    public static boolean containsType(List<? extends Expression> list, Class<?> cls) {
        return anyMatch(list, cls::isInstance);
    }

    /** Unions the nodes each expression's {@code collect(predicate)} returns into an immutable set. */
    public static <E> Set<E> collect(List<? extends Expression> list, Predicate<TreeNode<Expression>> predicate) {
        return list.stream()
                .flatMap(expression -> ExpressionUtils.<E>streamMatching(expression, predicate))
                .collect(ImmutableSet.toImmutableSet());
    }

    /** Like {@link #collect(List, Predicate)} but returns a set built by {@link Collectors#toSet()}. */
    public static <E> Set<E> mutableCollect(List<? extends Expression> list,
            Predicate<TreeNode<Expression>> predicate) {
        return list.stream()
                .flatMap(expression -> ExpressionUtils.<E>streamMatching(expression, predicate))
                .collect(Collectors.toSet());
    }

    /**
     * Concatenates, in list order, the nodes each expression's {@code collect(predicate)} returns into an
     * immutable list. Duplicates across different expressions are kept.
     */
    public static <E> List<E> collectAll(List<? extends Expression> list, Predicate<TreeNode<Expression>> predicate) {
        return list.stream()
                .flatMap(expression -> ExpressionUtils.<E>streamMatching(expression, predicate))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Streams {@code expression.collect(predicate)} viewed as elements of type {@code E}.
     * The cast is unchecked; callers choose {@code E} to match what the predicate selects.
     */
    @SuppressWarnings("unchecked")
    private static <E> Stream<E> streamMatching(Expression expression, Predicate<TreeNode<Expression>> predicate) {
        return ((Set<E>) expression.collect(predicate)).stream();
    }

    /**
     * Expands {@code ROLLUP(e1, ..., en)} into its grouping sets {@code [e1..en], [e1..en-1], ..., []}.
     * The inner lists are {@link List#subList} views of {@code list}.
     */
    public static List<List<Expression>> rollupToGroupingSets(List<Expression> list) {
        List<List<Expression>> groupingSets = new ArrayList<>();
        for (int size = list.size(); size >= 0; size--) {
            groupingSets.add(list.subList(0, size));
        }
        return groupingSets;
    }

    /**
     * Accepts simple "slot vs. literal" predicates, normalizing comparisons so the slot is on the left.
     *
     * <ul>
     *   <li>{@link Not}: always empty.</li>
     *   <li>{@link InPredicate}: kept if the compared expression is a slot and all options are literals.</li>
     *   <li>{@link ComparisonPredicate}: if the left side is a {@link Literal} it is commuted first. The
     *       (possibly commuted) predicate is kept if it is {@code slot op literal}.</li>
     *   <li>{@link IsNull}: kept if its child is a slot.</li>
     *   <li>Anything else: empty.</li>
     * </ul>
     */
    public static Optional<Expression> checkAndMaybeCommute(Expression expression) {
        if (expression instanceof Not) {
            return Optional.empty();
        }
        if (expression instanceof InPredicate) {
            InPredicate inPredicate = (InPredicate) expression;
            boolean slotInLiterals = inPredicate.getCompareExpr().isSlot()
                    && inPredicate.getOptions().stream().allMatch(Expression::isLiteral);
            return slotInLiterals ? Optional.of(expression) : Optional.empty();
        }
        if (expression instanceof ComparisonPredicate) {
            ComparisonPredicate comparison = (ComparisonPredicate) expression;
            if (comparison.left() instanceof Literal) {
                comparison = comparison.commute();
            }
            if (comparison.left().isSlot() && comparison.right().isLiteral()) {
                return Optional.of(comparison);
            }
            return Optional.empty();
        }
        if (expression instanceof IsNull) {
            return ((IsNull) expression).child().isSlot() ? Optional.of(expression) : Optional.empty();
        }
        return Optional.empty();
    }

    /**
     * Expands {@code CUBE(e1, ..., en)} into all 2^n subsets. Each subset preserves input order. Subsets that
     * include {@code e1} are listed before those that do not, recursively, so the full set comes first and the
     * empty set last.
     */
    public static List<List<Expression>> cubeToGroupingSets(List<Expression> list) {
        List<List<Expression>> groupingSets = new ArrayList<>();
        cubeToGroupingSets(list, 0, new ArrayList<>(), groupingSets);
        return groupingSets;
    }

    /**
     * Recursively enumerates subsets of {@code list[i..]} appended to {@code list2}, adding each complete
     * subset to {@code list3}. {@code list2} is never mutated; the "include" branch works on a copy.
     */
    private static void cubeToGroupingSets(List<Expression> list, int i, List<Expression> list2,
            List<List<Expression>> list3) {
        if (i == list.size()) {
            list3.add(list2);
            return;
        }
        List<Expression> withCurrent = new ArrayList<>(list2);
        withCurrent.add(list.get(i));
        cubeToGroupingSets(list, i + 1, withCurrent, list3);
        cubeToGroupingSets(list, i + 1, list2, list3);
    }

    /** Returns the immutable union of {@code getInputSlots()} over all expressions. */
    public static Set<Slot> getInputSlotSet(Collection<? extends Expression> collection) {
        return collection.stream()
                .<Slot>flatMap(expression -> expression.getInputSlots().stream())
                .collect(ImmutableSet.toImmutableSet());
    }

    /** Returns true if the expression, after stripping any chain of {@link Cast}s, is an instance of {@code cls}. */
    public static boolean checkTypeSkipCast(Expression expression, Class<? extends Expression> cls) {
        return cls.isInstance(getExpressionCoveredByCast(expression));
    }

    /** Strips any chain of {@link Cast}s and returns the innermost non-cast expression. */
    public static Expression getExpressionCoveredByCast(Expression expression) {
        Expression current = expression;
        while (current instanceof Cast) {
            current = ((Cast) current).child();
        }
        return current;
    }

    /**
     * Returns true if {@code set} contains an {@link EqualTo} that compares {@code slot} with a {@link Literal},
     * on either side.
     */
    public static boolean checkSlotConstant(Slot slot, Set<Expression> set) {
        return set.stream()
                .filter(EqualTo.class::isInstance)
                .map(EqualTo.class::cast)
                .anyMatch(equalTo -> (equalTo.left() instanceof Literal && equalTo.right().equals(slot))
                        || (equalTo.right() instanceof Literal && equalTo.left().equals(slot)));
    }
}
