package org.apache.doris.nereids.rules.expression.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.util.ExpressionUtils;

/**
 * Expression rewrite rule that factors out the sub-expressions shared by every branch of a
 * compound predicate.
 *
 * <p>The intended effect is a rewrite such as {@code (a AND b) OR (a AND c)} into
 * {@code a AND (b OR c)}.
 * NOTE(review): this relies on {@code ExpressionUtils.extract} returning the operands of a
 * compound predicate and {@code ExpressionUtils.combine} re-joining them (and simplifying
 * boolean literals / empty operand lists) — neither helper is visible here, confirm against
 * {@code ExpressionUtils}.
 */
@Developing
public class ExtractCommonFactorRule extends AbstractExpressionRewriteRule {
    public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule();

    /**
     * Rewrites the operands of {@code compoundPredicate} with this rule, then moves every
     * expression that occurs in all operands in front of the remaining, branch-specific parts.
     *
     * @param compoundPredicate the predicate to rewrite
     * @param expressionRewriteContext context handed on unchanged to the recursive rewrites
     * @return the combined expression; when re-combining the rewritten operands no longer
     *         yields a {@link CompoundPredicate}, that result is returned as is
     */
    @Override
    public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate,
            ExpressionRewriteContext expressionRewriteContext) {
        // Rewrite the operands first so that factoring works on already-normalised children.
        List<Expression> rewrittenOperands = ExpressionUtils.extract(compoundPredicate).stream()
                .map(operand -> rewrite(operand, expressionRewriteContext))
                .collect(ImmutableList.toImmutableList());
        Expression rewritten = ExpressionUtils.combine(compoundPredicate.getClass(), rewrittenOperands);
        if (!(rewritten instanceof CompoundPredicate)) {
            // Nothing left to factor once the predicate collapsed into a non-compound expression.
            return rewritten;
        }
        CompoundPredicate predicate = (CompoundPredicate) rewritten;

        List<List<Expression>> branches = splitIntoBranches(predicate);
        Set<Expression> commonFactors = intersectAll(branches);

        // Strip the common factors from each branch and re-join what is left of it
        // with the flipped operator.
        List<Expression> residuals = new ArrayList<>(branches.size());
        for (List<Expression> branch : branches) {
            List<Expression> remaining = branch.stream()
                    .filter(factor -> !commonFactors.contains(factor))
                    .collect(Collectors.toList());
            residuals.add(ExpressionUtils.combine(predicate.flipType(), remaining));
        }
        Expression combinedResiduals = ExpressionUtils.combine(predicate.getClass(), residuals);

        // Result: common factors first, followed by the combined branch-specific remainder.
        List<Expression> factored = Lists.newArrayList(commonFactors);
        factored.add(combinedResiduals);
        return ExpressionUtils.combine(predicate.flipType(), factored);
    }

    /**
     * Returns one list per operand of {@code predicate}: the extracted operands of a compound
     * operand, or a singleton list holding the operand itself otherwise.
     */
    private static List<List<Expression>> splitIntoBranches(CompoundPredicate predicate) {
        return ExpressionUtils.extract(predicate).stream()
                .<List<Expression>>map(operand -> operand instanceof CompoundPredicate
                        ? ExpressionUtils.extract((CompoundPredicate) operand)
                        : Lists.newArrayList(operand))
                .collect(Collectors.toList());
    }

    /**
     * Returns the expressions contained in every branch, or an empty set when there are no
     * branches.
     *
     * <p>The result is a materialised set whose iteration order follows the first branch, so
     * the rewritten expression does not depend on hash ordering.
     */
    private static Set<Expression> intersectAll(List<List<Expression>> branches) {
        Set<Expression> shared = null;
        for (List<Expression> branch : branches) {
            if (shared == null) {
                // Seed with the first branch; the linked set keeps its operand order.
                shared = Sets.newLinkedHashSet(branch);
            } else {
                // Hash-based lookup keeps each retainAll pass linear in the size of 'shared'.
                shared.retainAll(new HashSet<>(branch));
            }
        }
        if (shared == null) {
            return Collections.emptySet();
        }
        return shared;
    }
}
