package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.util.ExpressionUtils;

/**
 * Rewrite rule that pushes a {@link LogicalProject} below the {@link LogicalSetOperation} it sits on.
 *
 * <p>The rule only fires when:
 * <ul>
 *   <li>the project has exactly as many projections as the set operation has output columns, and</li>
 *   <li>every projection is either a {@link SlotReference}, or an expression whose first child is a
 *       {@link SlotReference} once {@link ExpressionUtils#getExpressionCoveredByCast} has been applied.</li>
 * </ul>
 *
 * <p>When it fires, the projection list is copied on top of every child of the set operation. In each
 * copy, the set operation's output slots are substituted by that child's regular output slots. The set
 * operation is then rebuilt with the original project's output as its own output, and with the slots
 * produced by the new child projects as its children's outputs. The rebuilt set operation is returned
 * as the replacement for the matched project.
 */
public class PushProjectThroughUnion extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalProject(logicalSetOperation())
                .when(project -> project.getProjects().size()
                        == ((LogicalSetOperation) project.child()).getOutput().size()
                        && project.getProjects().stream().allMatch(PushProjectThroughUnion::isSlotOrCastOfSlot))
                .then(project -> {
                    LogicalSetOperation setOperation = (LogicalSetOperation) project.child();
                    ImmutableList.Builder<Plan> newChildren = ImmutableList.builder();
                    ImmutableList.Builder<List<SlotReference>> newChildrenOutputs = ImmutableList.builder();
                    for (int i = 0; i < setOperation.arity(); i++) {
                        Map<Expression, Expression> replaceMap = buildReplaceMap(setOperation, i);
                        List<NamedExpression> childProjects =
                                rewriteProjectsForChild(project.getProjects(), replaceMap);
                        newChildren.add(new LogicalProject<>(childProjects, setOperation.child(i)));
                        // The slots produced by the new child project become that child's output list.
                        newChildrenOutputs.add(childProjects.stream()
                                .map(NamedExpression::toSlot)
                                .map(SlotReference.class::cast)
                                .collect(ImmutableList.toImmutableList()));
                    }
                    // copyOf lets the compiler adapt the project's slot list to the element type
                    // that withNewOutputs declares, without an unchecked cast.
                    return setOperation.withNewOutputs(ImmutableList.copyOf(project.getOutput()))
                            .withChildrenAndTheirOutputs(newChildren.build(), newChildrenOutputs.build());
                })
                .toRule(RuleType.PUSH_PROJECT_THROUGH_UNION);
    }

    /**
     * Reports whether a projection is simple enough to be pushed below the set operation.
     *
     * @param projection one projection of the matched project
     * @return {@code true} if the projection is a {@link SlotReference}, or if its first child is a
     *         {@link SlotReference} after {@link ExpressionUtils#getExpressionCoveredByCast} is applied
     */
    private static boolean isSlotOrCastOfSlot(NamedExpression projection) {
        if (projection instanceof SlotReference) {
            return true;
        }
        // NOTE(review): assumes every non-SlotReference projection has at least one child
        // (e.g. an Alias) — child(0) would fail otherwise.
        return ExpressionUtils.getExpressionCoveredByCast(projection.child(0)) instanceof SlotReference;
    }

    /**
     * Maps each output slot of the set operation to the positionally matching regular output slot of
     * one of its children.
     *
     * @param setOperation the set operation being rewritten
     * @param childIndex index of the child whose regular outputs are the replacement targets
     * @return a replace map from set-operation output slot to child output slot
     */
    private static Map<Expression, Expression> buildReplaceMap(LogicalSetOperation setOperation, int childIndex) {
        List<? extends Expression> setOutputs = setOperation.getOutput();
        List<? extends Expression> childOutputs = setOperation.getRegularChildOutput(childIndex);
        Map<Expression, Expression> replaceMap = Maps.newHashMap();
        for (int j = 0; j < setOutputs.size(); j++) {
            replaceMap.put(setOutputs.get(j), childOutputs.get(j));
        }
        return replaceMap;
    }

    /**
     * Rewrites the project's projections so they refer to one child's slots instead of the set
     * operation's output slots.
     *
     * @param projects the projections of the matched project
     * @param replaceMap substitution from set-operation output slot to child output slot
     * @return an immutable list of rewritten projections for that child
     */
    private static List<NamedExpression> rewriteProjectsForChild(
            List<NamedExpression> projects, Map<Expression, Expression> replaceMap) {
        return projects.stream()
                .map(e -> (NamedExpression) ExpressionUtils.replace(e, replaceMap))
                // Re-wrap every Alias in a newly constructed Alias (same child, same name) so that each
                // child gets its own Alias instance. Presumably this also gives each child a distinct
                // expression id — confirm against the Alias(Expression, String) constructor.
                .map(e -> e instanceof Alias ? new Alias(((Alias) e).child(), e.getName()) : e)
                .collect(ImmutableList.toImmutableList());
    }
}
