package org.apache.doris.nereids.jobs.joinorder;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.GraphSimplifier;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.SubgraphEnumerator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.PlanReceiver;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

/**
 * Job that reorders join trees found under {@link #group} with the DPHyp enumerator.
 *
 * <p>Each maximal join tree (as reported by {@code Group#isValidJoinGroup}) is turned into a
 * {@link HyperGraph}, enumerated with {@link SubgraphEnumerator} and replaced by the best group
 * returned from {@link PlanReceiver}. After all children are processed, statistics are derived
 * again for the root group's logical expression.
 */
public class JoinOrderJob extends Job {
    /** Limit passed to {@link PlanReceiver} and {@link GraphSimplifier}; also reported on failure. */
    private static final int ENUMERATION_LIMIT = 1000;

    private final Group group;

    /**
     * Project expressions that are neither slots nor aliases, collected by
     * {@link #processProjectPlan} while building the hypergraph of the join tree currently being
     * optimized. A fresh set is installed for every {@link #optimizeJoin} call, because join trees
     * can nest: a non-join leaf is optimized recursively and may contain another join tree.
     */
    private Set<NamedExpression> otherProject;

    public JoinOrderJob(Group group, JobContext jobContext) {
        super(JobType.JOIN_ORDER, jobContext);
        this.otherProject = new HashSet<>();
        this.group = group;
    }

    /**
     * Optimizes every child of the root group's logical expression in place. It then pushes a
     * {@link DeriveStatsJob} for that expression and runs the job pool.
     */
    @Override
    public void execute() throws AnalysisException {
        GroupExpression logicalExpression = this.group.getLogicalExpression();
        int arity = logicalExpression.arity();
        for (int i = 0; i < arity; i++) {
            logicalExpression.setChild(i, optimizePlan(logicalExpression.child(i)));
        }
        CascadesContext cascadesContext = this.context.getCascadesContext();
        cascadesContext.pushJob(new DeriveStatsJob(this.group.getLogicalExpression(),
                cascadesContext.getCurrentJobContext()));
        cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
    }

    /**
     * Recursively walks {@code group}. A join group is replaced by its reordered version; any
     * other group has its children optimized in place and is itself returned.
     */
    private Group optimizePlan(Group group) {
        if (group.isValidJoinGroup()) {
            return optimizeJoin(group);
        }
        GroupExpression logicalExpression = group.getLogicalExpression();
        int arity = logicalExpression.arity();
        for (int i = 0; i < arity; i++) {
            logicalExpression.setChild(i, optimizePlan(logicalExpression.child(i)));
        }
        return group;
    }

    /**
     * Reorders the join tree rooted at {@code group} and returns the best group found.
     *
     * <p>If the join tree contained projections that could not be pushed into the hypergraph
     * (non-slot, non-alias expressions), a {@link LogicalProject} is added on top of the best plan
     * to produce them again, together with the best plan's own output.
     *
     * @throws RuntimeException if enumeration still exceeds the limit after simplifying the graph
     */
    private Group optimizeJoin(Group group) {
        // buildGraph() -> optimizePlan() can re-enter this method for a nested join tree, so each
        // call gets its own collection. Otherwise projections leak between unrelated join trees.
        Set<NamedExpression> enclosingOtherProject = this.otherProject;
        this.otherProject = new HashSet<>();
        try {
            HyperGraph hyperGraph = new HyperGraph();
            buildGraph(group, hyperGraph);
            PlanReceiver planReceiver = new PlanReceiver(this.context, ENUMERATION_LIMIT, hyperGraph,
                    group.getLogicalProperties().getOutputSet());
            SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph);
            if (!subgraphEnumerator.enumerate()) {
                // Too many sub graphs: simplify the graph and try once more.
                new GraphSimplifier(hyperGraph).simplifyGraph(ENUMERATION_LIMIT);
                if (!subgraphEnumerator.enumerate()) {
                    throw new RuntimeException(
                            "DPHyp can not enumerate all sub graphs with limit=" + ENUMERATION_LIMIT);
                }
            }
            Group bestPlan = planReceiver.getBestPlan(hyperGraph.getNodesMap());
            if (!this.otherProject.isEmpty()) {
                this.otherProject.addAll(bestPlan.getLogicalExpression().getPlan().getOutput());
                LogicalProject project = new LogicalProject(new ArrayList<>(this.otherProject),
                        bestPlan.getLogicalExpression().getPlan());
                // The project's child plan is bestPlan's plan, so its memo child group must be
                // bestPlan as well. Pointing it at the original group would discard the reordering.
                GroupExpression projectExpression =
                        new GroupExpression(project, Lists.newArrayList(bestPlan));
                bestPlan = this.context.getCascadesContext().getMemo()
                        .copyInGroupExpression(projectExpression);
            }
            return bestPlan;
        } finally {
            this.otherProject = enclosingOtherProject;
        }
    }

    /**
     * Adds the join tree rooted at {@code group} to {@code hyperGraph}.
     *
     * <p>A project group is looked through: its child is added first, then its projections are
     * recorded. A join group is added as an edge between its left and right subtrees. Any other
     * group is optimized recursively and added as a node.
     *
     * @return the value returned by {@link HyperGraph#addEdge} for a join, the child's value for a
     *     project, or an empty {@link BitSet} for a leaf node
     */
    public BitSet buildGraph(Group group, HyperGraph hyperGraph) {
        if (group.isProjectGroup()) {
            BitSet childEdges = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
            processProjectPlan(hyperGraph, group);
            return childEdges;
        }
        if (group.isValidJoinGroup()) {
            // Left subtree is added before the right one, matching argument evaluation order.
            BitSet leftEdges = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
            BitSet rightEdges = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
            return hyperGraph.addEdge(group, leftEdges, rightEdges);
        }
        hyperGraph.addNode(optimizePlan(group));
        return new BitSet();
    }

    /**
     * Records the projections of a project group. Aliases are registered on the hypergraph.
     * Other non-slot expressions are kept in {@link #otherProject} so they can be rebuilt above
     * the reordered join. Plain slots need no handling.
     *
     * <p>Assumes {@code group} is a project group; the caller checks {@code isProjectGroup()}.
     */
    private void processProjectPlan(HyperGraph hyperGraph, Group group) {
        LogicalProject project = (LogicalProject) group.getLogicalExpression().getPlan();
        for (Object projection : project.getProjects()) {
            NamedExpression namedExpression = (NamedExpression) projection;
            if (namedExpression instanceof Alias) {
                hyperGraph.addAlias((Alias) namedExpression);
            } else if (!namedExpression.isSlot()) {
                this.otherProject.add(namedExpression);
            }
        }
    }
}
