package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.cascades.CostAndEnforcerJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.PlanUtils;

/* loaded from: input_file:org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.class */
/**
 * {@link AbstractReceiver} used by the hypergraph join-order enumeration: for every emitted pair of
 * sub-plans it proposes physical joins (optionally wrapped in a project), copies them into the memo
 * and remembers the resulting {@link Group} under the bitmap that identifies the joined node set.
 *
 * <p>The maps below are keyed by a {@code long} node bitmap (see {@link LongBitmap}).
 * {@link #complexProjectMap} is only populated by {@link #reset()}.
 */
public class PlanReceiver implements AbstractReceiver {
    /** Once {@link #emitCount} exceeds this value, {@link #emitCsgCmp} returns {@code false}. */
    int limit;
    JobContext jobContext;
    HyperGraph hyperGraph;
    /** Slots that are always part of the required-slot set computed for each proposed plan. */
    final Set<Slot> finalOutputs;
    /** Node bitmap -> memo group holding the plans proposed for that node set. */
    HashMap<Long, Group> planTable = new HashMap<>();
    /** Node bitmap -> indexes ({@code Edge#getIndex()}) of the edges already applied for that node set. */
    HashMap<Long, BitSet> usdEdges = new HashMap<>();
    /** Copy of {@code hyperGraph.getComplexProject()}, refreshed by {@link #reset()}. */
    HashMap<Long, List<NamedExpression>> complexProjectMap = new HashMap<>();
    /** Number of {@link #emitCsgCmp} calls since construction or the last {@link #reset()}. */
    int emitCount = 0;

    /**
     * @param jobContext context giving access to the cascades context and memo
     * @param i          emit limit, stored in {@link #limit}
     * @param hyperGraph graph whose edges and complex projects are consulted
     * @param set        final output slots, stored in {@link #finalOutputs}
     */
    public PlanReceiver(JobContext jobContext, int i, HyperGraph hyperGraph, Set<Slot> set) {
        this.jobContext = jobContext;
        this.limit = i;
        this.hyperGraph = hyperGraph;
        this.finalOutputs = set;
    }

    /**
     * Proposes the physical plans that join the two already-registered node sets {@code j} and
     * {@code j2} and copies them into the memo group of their union.
     *
     * @param j    bitmap of the left node set; must already be in {@link #planTable}
     * @param j2   bitmap of the right node set; must already be in {@link #planTable}
     * @param list edges connecting both sides; edges fully covered by the union that were not
     *             applied yet are appended to this list (the argument is mutated)
     * @return {@code false} once the emit limit is exceeded, {@code true} otherwise (including the
     *         case where no single join type can be derived and nothing is proposed)
     * @throws IllegalArgumentException if {@code j} or {@code j2} is not in {@link #planTable}
     */
    @Override
    public boolean emitCsgCmp(long j, long j2, List<Edge> list) {
        Preconditions.checkArgument(this.planTable.containsKey(j));
        Preconditions.checkArgument(this.planTable.containsKey(j2));
        processMissedEdges(j, j2, list);
        Memo memo = this.jobContext.getCascadesContext().getMemo();
        this.emitCount++;
        if (this.emitCount > this.limit) {
            return false;
        }

        GroupPlan leftPlan = new GroupPlan(this.planTable.get(j));
        GroupPlan rightPlan = new GroupPlan(this.planTable.get(j2));
        List<Expression> hashConjuncts = new ArrayList<>();
        List<Expression> otherConjuncts = new ArrayList<>();
        JoinType joinType = extractJoinTypeAndConjuncts(list, hashConjuncts, otherConjuncts);
        if (joinType == null) {
            // No edges, or edges with differing join types: nothing to propose for this pair.
            return true;
        }

        long fullKey = LongBitmap.newBitmapUnion(j, j2);
        List<Plan> physicalJoins =
                proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts, otherConjuncts);
        List<Plan> physicalPlans = proposeProject(physicalJoins, list, j, j2);
        if (!this.planTable.containsKey(fullKey)) {
            this.planTable.put(fullKey, memo.newGroup(physicalPlans.get(0).getLogicalProperties()));
        }
        Group targetGroup = this.planTable.get(fullKey);
        for (Plan physicalPlan : physicalPlans) {
            proposeAllDistributedPlans(
                    memo.copyIn(physicalPlan, targetGroup, false, this.planTable).correspondingExpression);
        }
        return true;
    }

    /**
     * Computes the slots that must survive above the join of {@code left} and {@code right}: the
     * final outputs, the input slots of every edge not applied yet, and the input slots of every
     * complex project whose key is not a subset of the union.
     *
     * <p>Side effect: records the applied-edge set of the union in {@link #usdEdges}.
     */
    private Set<Slot> calculateRequiredSlots(long left, long right, List<Edge> edges) {
        long fullKey = LongBitmap.newBitmapUnion(left, right);
        Set<Slot> requiredSlots = new HashSet<>(this.finalOutputs);

        BitSet appliedEdges = new BitSet();
        appliedEdges.or(this.usdEdges.get(left));
        appliedEdges.or(this.usdEdges.get(right));
        for (Edge edge : edges) {
            appliedEdges.set(edge.getIndex());
        }
        this.usdEdges.put(fullKey, appliedEdges);

        for (Edge edge : this.hyperGraph.getEdges()) {
            if (!appliedEdges.get(edge.getIndex())) {
                requiredSlots.addAll(edge.getInputSlots());
            }
        }
        this.hyperGraph.getComplexProject().entrySet().stream()
                .filter(entry -> !LongBitmap.isSubset(entry.getKey(), fullKey))
                .flatMap(entry -> entry.getValue().stream())
                .forEach(project -> requiredSlots.addAll(project.getInputSlots()));
        return requiredSlots;
    }

    /**
     * Appends to {@code edges} every graph edge whose original left and right node sets are both
     * covered by {@code left | right} and that is neither already applied on one of the sides nor
     * already present in {@code edges}.
     */
    private void processMissedEdges(long left, long right, List<Edge> edges) {
        BitSet appliedEdges = new BitSet();
        appliedEdges.or(this.usdEdges.get(left));
        appliedEdges.or(this.usdEdges.get(right));
        for (Edge edge : edges) {
            appliedEdges.set(edge.getIndex());
        }

        long allNodes = LongBitmap.or(left, right);
        for (Edge candidate : this.hyperGraph.getEdges()) {
            long edgeNodes = LongBitmap.newBitmapUnion(candidate.getOriginalLeft(), candidate.getOriginalRight());
            if (LongBitmap.isSubset(edgeNodes, allNodes) && !appliedEdges.get(candidate.getIndex())) {
                edges.add(candidate);
            }
        }
    }

    /**
     * Pushes a {@link CostAndEnforcerJob} (and, when statistics are not derived yet, a
     * {@link DeriveStatsJob}) for the expression and then runs the job pool.
     */
    private void proposeAllDistributedPlans(GroupExpression groupExpression) {
        this.jobContext.getCascadesContext().pushJob(new CostAndEnforcerJob(groupExpression,
                new JobContext(this.jobContext.getCascadesContext(), PhysicalProperties.ANY, Double.MAX_VALUE)));
        if (!groupExpression.isStatDerived()) {
            // NOTE(review): pushed after the cost job; presumably the job pool is LIFO so stats are
            // derived first -- confirm against the scheduler.
            this.jobContext.getCascadesContext().pushJob(new DeriveStatsJob(groupExpression,
                    this.jobContext.getCascadesContext().getCurrentJobContext()));
        }
        this.jobContext.getCascadesContext().getJobScheduler().executeJobPool(this.jobContext.getCascadesContext());
    }

    /**
     * Builds the candidate physical joins for {@code left} and {@code right}: a nested-loop join
     * when {@link JoinUtils#shouldNestedLoopJoin} says so, otherwise a hash join, plus the swapped
     * variant when the join type is swappable.
     *
     * @return a non-empty list of candidate joins
     */
    private List<Plan> proposeAllPhysicalJoins(JoinType joinType, Plan left, Plan right,
            List<Expression> hashConjuncts, List<Expression> otherConjuncts) {
        LogicalProperties joinProperties =
                new LogicalProperties(() -> JoinUtils.getJoinOutput(joinType, left, right));
        List<Plan> plans = Lists.newArrayList();
        // NOTE(review): the swapped variants reuse the same conjunct lists and logical properties
        // as the unswapped join -- confirm that this is intended.
        if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
            plans.add(new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
                    Optional.empty(), joinProperties, left, right));
            if (joinType.isSwapJoinType()) {
                plans.add(new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts,
                        Optional.empty(), joinProperties, right, left));
            }
        } else {
            plans.add(new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE,
                    Optional.empty(), joinProperties, left, right));
            if (joinType.isSwapJoinType()) {
                plans.add(new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
                        Optional.empty(), joinProperties, right, left));
            }
        }
        return plans;
    }

    /**
     * Collects the hash and other conjuncts of all edges and returns their common join type.
     *
     * @param edges          edges to inspect
     * @param hashConjuncts  output list receiving the hash-join conjuncts
     * @param otherConjuncts output list receiving the remaining conjuncts
     * @return the shared join type, or {@code null} when {@code edges} is empty or the edges do not
     *         all have the same join type (the output lists may then be partially filled)
     */
    @Nullable
    private JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
            List<Expression> otherConjuncts) {
        JoinType joinType = null;
        for (Edge edge : edges) {
            if (joinType != null && edge.getJoinType() != joinType) {
                return null;
            }
            joinType = edge.getJoinType();
            hashConjuncts.addAll(edge.getHashJoinConjuncts());
            otherConjuncts.addAll(edge.getOtherJoinConjuncts());
        }
        return joinType;
    }

    /**
     * Returns whether any of the edges belongs to a mark join. Not called from within this class.
     *
     * @throws IllegalArgumentException if the edges do not all share one join type
     */
    private boolean extractIsMarkJoin(List<Edge> edges) {
        boolean isMarkJoin = false;
        JoinType joinType = null;
        for (Edge edge : edges) {
            Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
            isMarkJoin = edge.getJoin().isMarkJoin() || isMarkJoin;
            joinType = edge.getJoinType();
        }
        return isMarkJoin;
    }

    /**
     * Registers the group of a single node, wrapping it in a project first when
     * {@link #proposeProject} decides one is needed.
     *
     * @param j     bitmap with exactly one bit set
     * @param group memo group of that node
     * @throws IllegalArgumentException if the cardinality of {@code j} is not 1
     */
    @Override
    public void addGroup(long j, Group group) {
        Preconditions.checkArgument(LongBitmap.getCardinality(j) == 1);
        // proposeProject reads usdEdges for both of its bitmap arguments, so seed the entry first.
        this.usdEdges.put(j, new BitSet());
        Plan plan = proposeProject(Lists.<Plan>newArrayList(new GroupPlan(group)), new ArrayList<>(), j, j).get(0);
        if (!(plan instanceof GroupPlan)) {
            group = this.jobContext.getCascadesContext().getMemo()
                    .copyIn(plan, null, false, this.planTable).correspondingExpression.getOwnerGroup();
        }
        this.planTable.put(j, group);
        // proposeProject stored an entry under the union of its bitmap arguments; start this node
        // with a fresh, empty applied-edge set.
        this.usdEdges.put(j, new BitSet());
    }

    /** Returns whether a group is already registered for the node set {@code j}. */
    @Override
    public boolean contain(long j) {
        return this.planTable.containsKey(j);
    }

    /** Clears all per-enumeration state and reloads the complex projects from the hypergraph. */
    @Override
    public void reset() {
        this.emitCount = 0;
        this.planTable.clear();
        this.usdEdges.clear();
        this.complexProjectMap.clear();
        this.complexProjectMap.putAll(this.hyperGraph.getComplexProject());
    }

    /**
     * Ensures logical expressions exist for the group registered under {@code j} (and for the
     * groups reachable through its best plans) and returns that group.
     */
    @Override
    public Group getBestPlan(long j) {
        makeLogicalExpression(() -> this.planTable.get(j));
        return this.planTable.get(j);
    }

    /**
     * If the supplied group has no logical expression yet, converts each distinct best physical
     * plan of the group (skipping {@link PhysicalDistribute}) into its logical counterpart,
     * recursing into the children first, and copies it into the group.
     *
     * <p>The group is always re-read through {@code supplier} instead of being cached, because
     * {@code copyIn} is handed {@link #planTable} and may therefore change what the supplier returns.
     *
     * @throws RuntimeException if a best plan is neither a project nor a join
     */
    private void makeLogicalExpression(Supplier<Group> supplier) {
        if (!supplier.get().getLogicalExpressions().isEmpty()) {
            return;
        }
        Set<GroupExpression> visited = new HashSet<>();
        for (PhysicalProperties properties : supplier.get().getAllProperties()) {
            GroupExpression best = supplier.get().getBestPlan(properties);
            if (visited.contains(best) || best.getPlan() instanceof PhysicalDistribute) {
                continue;
            }
            visited.add(best);

            Plan physicalPlan = best.getPlan();
            for (int i = 0; i < best.children().size(); i++) {
                final int childIndex = i;
                makeLogicalExpression(() -> best.child(childIndex));
            }

            Plan logicalPlan;
            if (physicalPlan instanceof PhysicalProject) {
                logicalPlan = new LogicalProject<>(((PhysicalProject<?>) physicalPlan).getProjects(),
                        new GroupPlan(best.child(0)));
            } else if (physicalPlan instanceof AbstractPhysicalJoin) {
                AbstractPhysicalJoin<?, ?> physicalJoin = (AbstractPhysicalJoin<?, ?>) physicalPlan;
                List<Plan> childPlans = best.children().stream()
                        .<Plan>map(childGroup -> new GroupPlan(childGroup))
                        .collect(Collectors.toList());
                logicalPlan = new LogicalJoin<>(physicalJoin.getJoinType(), physicalJoin.getHashJoinConjuncts(),
                        physicalJoin.getOtherJoinConjuncts(), JoinHint.NONE,
                        physicalJoin.getMarkJoinSlotReference(), childPlans);
            } else {
                throw new RuntimeException("DPhyp can only handle join and project operator");
            }
            this.jobContext.getCascadesContext().getMemo().copyIn(logicalPlan, supplier.get(), false, this.planTable);
        }
    }

    /**
     * Wraps every plan in {@code children} in a {@link PhysicalProject} when the projection
     * computed for the union of {@code left} and {@code right} differs from the children's output.
     *
     * <p>The projection consists of the complex projects that become computable exactly at this
     * union (or at a single node when {@code left == right}) followed by the child output slots
     * that are still required, see {@link #calculateRequiredSlots}. If that would be empty, the
     * column chosen by {@link ExpressionUtils#selectMinimumColumn} is projected instead.
     *
     * @param children candidate plans; all are assumed to share the output of the first one
     * @param edges    edges applied by this join (empty for a single node)
     * @return {@code children} itself when no project is needed, otherwise a new list of projects
     * @throws IllegalStateException if some projection cannot be computed from the child output
     */
    private List<Plan> proposeProject(List<Plan> children, List<Edge> edges, long left, long right) {
        long fullKey = LongBitmap.newBitmapUnion(left, right);
        List<Slot> outputs = children.get(0).getOutput();
        Set<Slot> outputSet = children.get(0).getOutputSet();
        List<NamedExpression> allProjects = Lists.newArrayList();

        // Complex projects covered by the union but by neither side alone (or any covered one
        // when both sides are the same node).
        List<Long> projectKeys = this.complexProjectMap.keySet().stream()
                .filter(key -> LongBitmap.isSubset(key, fullKey)
                        && (left == right
                        || !(LongBitmap.isSubset(key, left) || LongBitmap.isSubset(key, right))))
                .collect(Collectors.toList());
        List<NamedExpression> complexProjects = new ArrayList<>();
        for (Long key : projectKeys) {
            if (complexProjects.isEmpty()) {
                complexProjects = this.complexProjectMap.get(key);
            } else {
                complexProjects = PlanUtils.mergeProjections(complexProjects, this.complexProjectMap.get(key));
            }
        }
        allProjects.addAll(complexProjects);

        // Child output slots still needed above this plan.
        Set<Slot> requiredSlots = calculateRequiredSlots(left, right, edges);
        for (Slot slot : outputs) {
            if (requiredSlots.contains(slot)) {
                allProjects.add(slot);
            }
        }

        if (allProjects.isEmpty()) {
            allProjects.add(ExpressionUtils.selectMinimumColumn(outputs));
        }
        if (outputSet.equals(new HashSet<>(allProjects))) {
            return children;
        }

        // Keep only projections whose inputs are all produced by the child.
        List<NamedExpression> validProjects = allProjects.stream()
                .filter(project -> outputSet.containsAll(project.getInputSlots()))
                .collect(Collectors.toList());
        List<Plan> result = children;
        if (!outputSet.equals(new HashSet<>(validProjects))) {
            LogicalProperties projectProperties = new LogicalProperties(
                    () -> validProjects.stream().map(project -> project.toSlot()).collect(Collectors.toList()));
            result = children.stream()
                    .<Plan>map(child -> new PhysicalProject<>(validProjects, projectProperties, child))
                    .collect(Collectors.toList());
        }
        Preconditions.checkState(!validProjects.isEmpty() && validProjects.size() == allProjects.size());
        return result;
    }
}
