package org.apache.doris.nereids.rules.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;

/* loaded from: input_file:org/apache/doris/nereids/rules/expression/ExpressionRewrite.class */
/**
 * Rewrites the expressions held by logical plan nodes using a shared {@link ExpressionRuleExecutor}.
 *
 * <p>{@link #buildRules()} produces one rule per supported plan node type: generate, one-row relation,
 * project, aggregate, filter, join, sort, repeat and having. Each rule runs the executor over that node's
 * expressions. It returns the original node when no expression changed, and a rebuilt node otherwise.
 */
public class ExpressionRewrite implements RewriteRuleFactory {
    private final ExpressionRuleExecutor rewriter;

    /** Rewrites the generator functions of a {@link LogicalGenerate}. */
    private class GenerateExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalGenerate().thenApply(ctx -> {
                LogicalGenerate generate = (LogicalGenerate) ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<Function> generators = generate.getGenerators();
                List<Function> newGenerators = generators.stream()
                        .map(generator -> (Function) rewriter.rewrite(generator, context))
                        .collect(ImmutableList.toImmutableList());
                return generators.equals(newGenerators) ? generate : generate.withGenerators(newGenerators);
            }).toRule(RuleType.REWRITE_GENERATE_EXPRESSION);
        }
    }

    /** Rewrites the projected expressions of a {@link LogicalOneRowRelation}. */
    private class OneRowRelationExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalOneRowRelation().thenApply(ctx -> {
                LogicalOneRowRelation oneRowRelation = (LogicalOneRowRelation) ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<NamedExpression> projects = oneRowRelation.getProjects();
                List<NamedExpression> newProjects = rewriteNamedExpressions(projects, context);
                return projects.equals(newProjects)
                        ? oneRowRelation
                        : new LogicalOneRowRelation(oneRowRelation.getRelationId(), newProjects);
            }).toRule(RuleType.REWRITE_ONE_ROW_RELATION_EXPRESSION);
        }
    }

    /** Rewrites the projected expressions of a {@link LogicalProject}. */
    private class ProjectExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalProject().thenApply(ctx -> {
                LogicalProject project = (LogicalProject) ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<NamedExpression> projects = project.getProjects();
                List<NamedExpression> newProjects = rewriteNamedExpressions(projects, context);
                return projects.equals(newProjects)
                        ? project
                        : project.withProjectsAndChild(newProjects, (Plan) project.child());
            }).toRule(RuleType.REWRITE_PROJECT_EXPRESSION);
        }
    }

    /** Rewrites both the group-by and the output expressions of a {@link LogicalAggregate}. */
    private class AggExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalAggregate().thenApply(ctx -> {
                LogicalAggregate aggregate = (LogicalAggregate) ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<Expression> groupByExpressions = aggregate.getGroupByExpressions();
                List<Expression> newGroupByExpressions = rewriter.rewrite(groupByExpressions, context);
                List<NamedExpression> outputExpressions = aggregate.getOutputExpressions();
                List<NamedExpression> newOutputExpressions = rewriteNamedExpressions(outputExpressions, context);
                // Compare both lists: checking only the outputs would silently discard a rewrite
                // that changed nothing but the group-by keys.
                if (groupByExpressions.equals(newGroupByExpressions)
                        && outputExpressions.equals(newOutputExpressions)) {
                    return aggregate;
                }
                return new LogicalAggregate(newGroupByExpressions, newOutputExpressions,
                        aggregate.isNormalized(), aggregate.getSourceRepeat(), (Plan) aggregate.child());
            }).toRule(RuleType.REWRITE_AGG_EXPRESSION);
        }
    }

    /** Rewrites the predicate of a {@link LogicalFilter} and stores it back as a set of conjuncts. */
    private class FilterExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalFilter().thenApply(ctx -> {
                LogicalFilter filter = (LogicalFilter) ctx.root;
                ImmutableSet<Expression> newConjuncts = rewriteToConjuncts(filter.getPredicate(),
                        new ExpressionRewriteContext(ctx.cascadesContext));
                return newConjuncts.equals(filter.getConjuncts())
                        ? filter
                        : new LogicalFilter(newConjuncts, (Plan) filter.child());
            }).toRule(RuleType.REWRITE_FILTER_EXPRESSION);
        }
    }

    /** Rewrites the hash-join and other-join conjuncts of a {@link LogicalJoin}. */
    private class JoinExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalJoin().thenApply(ctx -> {
                LogicalJoin join = (LogicalJoin) ctx.root;
                List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
                List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts();
                if (hashJoinConjuncts.isEmpty() && otherJoinConjuncts.isEmpty()) {
                    return join;
                }
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<Expression> newHashJoinConjuncts = Lists.newArrayList();
                List<Expression> newOtherJoinConjuncts = Lists.newArrayList();
                // Both lists are always rewritten (hash conjuncts first) so the executor sees
                // every conjunct, even once a change has already been detected.
                boolean hashChanged = rewriteConjunctsInto(hashJoinConjuncts, newHashJoinConjuncts, context);
                boolean otherChanged = rewriteConjunctsInto(otherJoinConjuncts, newOtherJoinConjuncts, context);
                if (!hashChanged && !otherChanged) {
                    return join;
                }
                return new LogicalJoin(join.getJoinType(), newHashJoinConjuncts, newOtherJoinConjuncts,
                        join.getHint(), join.getMarkJoinSlotReference(), join.children());
            }).toRule(RuleType.REWRITE_JOIN_EXPRESSION);
        }
    }

    /** Rewrites the expression inside each order key of a {@link LogicalSort}. */
    private class SortExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalSort().thenApply(ctx -> {
                LogicalSort sort = (LogicalSort) ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<OrderKey> orderKeys = sort.getOrderKeys();
                List<OrderKey> newOrderKeys = new ArrayList<>(orderKeys.size());
                boolean changed = false;
                for (OrderKey orderKey : orderKeys) {
                    Expression expression = orderKey.getExpr();
                    Expression rewritten = rewriter.rewrite(expression, context);
                    if (rewritten.equals(expression)) {
                        newOrderKeys.add(orderKey);
                    } else {
                        changed = true;
                        // Keep the original direction and null ordering; only the expression changes.
                        newOrderKeys.add(new OrderKey(rewritten, orderKey.isAsc(), orderKey.isNullFirst()));
                    }
                }
                // Like the other rules, return the original node when nothing was rewritten.
                return changed ? sort.withOrderKeys(newOrderKeys) : sort;
            }).toRule(RuleType.REWRITE_SORT_EXPRESSION);
        }
    }

    /** Rewrites the grouping sets and output expressions of a {@link LogicalRepeat}. */
    private class LogicalRepeatRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalRepeat().thenApply(ctx -> {
                LogicalRepeat repeat = (LogicalRepeat) ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<List<Expression>> groupingSets = repeat.getGroupingSets();
                ImmutableList.Builder<List<Expression>> groupingSetsBuilder = ImmutableList.builder();
                for (List<Expression> groupingSet : groupingSets) {
                    groupingSetsBuilder.add(groupingSet.stream()
                            .map(expression -> rewriter.rewrite(expression, context))
                            .collect(ImmutableList.toImmutableList()));
                }
                List<List<Expression>> newGroupingSets = groupingSetsBuilder.build();
                List<NamedExpression> outputExpressions = repeat.getOutputExpressions();
                List<NamedExpression> newOutputExpressions = rewriteNamedExpressions(outputExpressions, context);
                // Like the other rules, return the original node when nothing was rewritten.
                if (groupingSets.equals(newGroupingSets) && outputExpressions.equals(newOutputExpressions)) {
                    return repeat;
                }
                return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputExpressions);
            }).toRule(RuleType.REWRITE_REPEAT_EXPRESSION);
        }
    }

    /** Rewrites the predicate of a {@link LogicalHaving} and stores it back as a set of conjuncts. */
    private class HavingExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalHaving().thenApply(ctx -> {
                LogicalHaving having = (LogicalHaving) ctx.root;
                ImmutableSet<Expression> newConjuncts = rewriteToConjuncts(having.getPredicate(),
                        new ExpressionRewriteContext(ctx.cascadesContext));
                return newConjuncts.equals(having.getConjuncts()) ? having : having.withExpressions(newConjuncts);
            }).toRule(RuleType.REWRITE_HAVING_EXPRESSION);
        }
    }

    /**
     * Creates a rewrite whose executor is built from the given rules.
     *
     * @param expressionRewriteRuleArr the expression rules handed to a new {@link ExpressionRuleExecutor}
     */
    public ExpressionRewrite(ExpressionRewriteRule... expressionRewriteRuleArr) {
        this.rewriter = new ExpressionRuleExecutor(ImmutableList.copyOf(expressionRewriteRuleArr));
    }

    /**
     * Creates a rewrite backed by an existing executor.
     *
     * @param expressionRuleExecutor the executor to use; must not be null
     * @throws NullPointerException if {@code expressionRuleExecutor} is null
     */
    public ExpressionRewrite(ExpressionRuleExecutor expressionRuleExecutor) {
        this.rewriter = Objects.requireNonNull(expressionRuleExecutor, "rewriter is null");
    }

    /**
     * Rewrites a single expression with this factory's executor.
     *
     * @param expression the expression to rewrite
     * @param expressionRewriteContext the context passed through to the executor
     * @return the rewritten expression as returned by the executor
     */
    public Expression rewrite(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
        return this.rewriter.rewrite(expression, expressionRewriteContext);
    }

    /**
     * Builds one expression-rewrite rule per supported logical plan node type.
     *
     * @return the rules, in a fixed order
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                new GenerateExpressionRewrite().build(),
                new OneRowRelationExpressionRewrite().build(),
                new ProjectExpressionRewrite().build(),
                new AggExpressionRewrite().build(),
                new FilterExpressionRewrite().build(),
                new JoinExpressionRewrite().build(),
                new SortExpressionRewrite().build(),
                new LogicalRepeatRewrite().build(),
                new HavingExpressionRewrite().build());
    }

    /** Rewrites each named expression and returns the results, in order, as an immutable list. */
    private List<NamedExpression> rewriteNamedExpressions(List<NamedExpression> expressions,
            ExpressionRewriteContext context) {
        return expressions.stream()
                .map(expression -> (NamedExpression) rewriter.rewrite(expression, context))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Rewrites a predicate and splits the result with {@link ExpressionUtils#extractConjunction}.
     * Duplicate conjuncts collapse because the result is a set.
     */
    private ImmutableSet<Expression> rewriteToConjuncts(Expression predicate, ExpressionRewriteContext context) {
        return ImmutableSet.copyOf(ExpressionUtils.extractConjunction(rewriter.rewrite(predicate, context)));
    }

    /**
     * Rewrites each conjunct in order and appends its split result to {@code sink}.
     *
     * @param conjuncts the conjuncts to rewrite
     * @param sink the list that receives the rewritten, split conjuncts
     * @param context the rewrite context passed through to the executor
     * @return true if any conjunct differs from its rewritten form
     */
    private boolean rewriteConjunctsInto(List<Expression> conjuncts, List<Expression> sink,
            ExpressionRewriteContext context) {
        boolean changed = false;
        for (Expression conjunct : conjuncts) {
            Expression rewritten = rewriter.rewrite(conjunct, context);
            // Change is judged before splitting, so an unchanged conjunct never counts as a change.
            changed |= !rewritten.equals(conjunct);
            sink.addAll(ExpressionUtils.extractConjunction(rewritten));
        }
        return changed;
    }
}
