package org.apache.doris.nereids.jobs.joinorder.hypergraph;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.PlanUtils;

/* loaded from: input_file:org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.class */
/**
 * Join hypergraph built for join reordering (the DPhyper enumerator's input).
 *
 * <p>Nodes are non-join {@link Group}s registered through {@link #addNode}; node {@code k} is represented by bit
 * {@code k} in the {@code long} bitmaps handled by {@link LongBitmap}. Edges are created by {@link #addEdge} from a
 * join group: its conjuncts are bucketed by the pair of node sets (left end, right end) they connect, and one
 * {@link Edge} is created per bucket.
 */
public class HyperGraph {
    private final List<Edge> edges = new ArrayList<>();
    private final List<Node> nodes = new ArrayList<>();
    private final HashSet<Group> nodeSet = new HashSet<>();
    // Maps every known slot to the bitmap of node indexes it is derived from.
    private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
    // Aliases registered through addAlias, keyed by the bitmap of nodes their inputs come from.
    private final HashMap<Long, List<NamedExpression>> complexProject = new HashMap<>();

    public List<Edge> getEdges() {
        return edges;
    }

    public List<Node> getNodes() {
        return nodes;
    }

    /** Returns a bitmap containing every node index currently in the graph. */
    public long getNodesMap() {
        return LongBitmap.newBitmapBetween(0, nodes.size());
    }

    public Edge getEdge(int i) {
        return edges.get(i);
    }

    public Node getNode(int i) {
        return nodes.get(i);
    }

    /**
     * Registers an alias so that conjuncts referencing its output slot can be resolved to nodes.
     *
     * <p>The alias slot is mapped to the union of the node bitmaps of its input slots, and the alias is appended to
     * the {@code complexProject} bucket for that bitmap. When the bucket already exists and the alias is not a plain
     * slot rename, it is first merged with the projections already in the bucket.
     *
     * @param alias the alias to register
     * @return always {@code true}
     * @throws IllegalArgumentException if an input slot of the alias is not known to this graph
     */
    public boolean addAlias(Alias alias) {
        Slot aliasSlot = alias.toSlot();
        if (slotToNodeMap.containsKey(aliasSlot)) {
            return true;
        }
        long bitmap = LongBitmap.newBitmap();
        for (Slot inputSlot : alias.getInputSlots()) {
            Long inputBitmap = slotToNodeMap.get(inputSlot);
            // Fail with a readable message instead of an unboxing NullPointerException.
            Preconditions.checkArgument(inputBitmap != null, "alias input slot %s is not produced by any node",
                    inputSlot);
            bitmap = LongBitmap.or(bitmap, inputBitmap);
        }
        slotToNodeMap.put(aliasSlot, bitmap);

        List<NamedExpression> projects = complexProject.get(bitmap);
        Alias toAdd = alias;
        if (projects == null) {
            projects = new ArrayList<>();
            complexProject.put(bitmap, projects);
        } else if (!(alias.child() instanceof SlotReference)) {
            toAdd = (Alias) PlanUtils.mergeProjections(projects, Lists.<NamedExpression>newArrayList(alias)).get(0);
        }
        projects.add(toAdd);
        return true;
    }

    /**
     * Adds a non-join group as a new node; each of its output slots is mapped to the new node's bit.
     *
     * @throws IllegalArgumentException if the group is a join group or one of its output slots is already mapped
     */
    public void addNode(Group group) {
        Preconditions.checkArgument(!group.isValidJoinGroup(), "a join group cannot be a hypergraph node");
        int nodeIndex = nodes.size();
        for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
            Preconditions.checkArgument(!slotToNodeMap.containsKey(slot), "slot %s is already mapped", slot);
            slotToNodeMap.put(slot, LongBitmap.newBitmap(nodeIndex));
        }
        nodeSet.add(group);
        nodes.add(new Node(nodeIndex, group));
    }

    public boolean isNodeGroup(Group group) {
        return nodeSet.contains(group);
    }

    public HashMap<Long, List<NamedExpression>> getComplexProject() {
        return complexProject;
    }

    /**
     * Splits the conjuncts of a join group into edges and adds them to the graph.
     *
     * <p>Hash and other join conjuncts are bucketed by the (left end, right end) pair computed by {@link #findEnds};
     * one edge per bucket is created with the join's type and mark slot, its ends are widened by
     * {@link #initEdgeEnds}, and it is attached to every node it references.
     *
     * @param group a join group
     * @param leftEdges indexes of already-added edges used to widen the left end (presumably the edges below the
     *     join's left input — confirm with caller)
     * @param rightEdges indexes of already-added edges used to widen the right end (presumably the edges below the
     *     join's right input — confirm with caller)
     * @return a new bit set containing {@code leftEdges}, {@code rightEdges} and the indexes of the edges added here
     * @throws IllegalArgumentException if {@code group} is not a join group
     */
    public BitSet addEdge(Group group, BitSet leftEdges, BitSet rightEdges) {
        Preconditions.checkArgument(group.isValidJoinGroup(), "an edge must be built from a join group");
        LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) group.getLogicalExpression().getPlan();

        // ends -> (hash conjuncts, other conjuncts) connecting exactly those ends
        Map<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjunctsByEnds = new HashMap<>();
        for (Expression conjunct : join.getHashJoinConjuncts()) {
            Pair<List<Expression>, List<Expression>> bucket = conjunctsByEnds.computeIfAbsent(findEnds(conjunct),
                    ends -> Pair.of(new ArrayList<>(), new ArrayList<>()));
            bucket.first.add(conjunct);
        }
        for (Expression conjunct : join.getOtherJoinConjuncts()) {
            Pair<List<Expression>, List<Expression>> bucket = conjunctsByEnds.computeIfAbsent(findEnds(conjunct),
                    ends -> Pair.of(new ArrayList<>(), new ArrayList<>()));
            bucket.second.add(conjunct);
        }

        BitSet subtreeEdges = new BitSet();
        subtreeEdges.or(leftEdges);
        subtreeEdges.or(rightEdges);
        for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry
                : conjunctsByEnds.entrySet()) {
            Pair<List<Expression>, List<Expression>> conjuncts = entry.getValue();
            LogicalJoin<Plan, Plan> singleJoin = new LogicalJoin<>(join.getJoinType(), conjuncts.first,
                    conjuncts.second, JoinHint.NONE, join.getMarkJoinSlotReference(),
                    Lists.<Plan>newArrayList(join.left(), join.right()));
            Edge edge = new Edge(singleJoin, edges.size());
            initEdgeEnds(entry.getKey(), edge, leftEdges, rightEdges);
            for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
                nodes.get(nodeIndex).attachEdge(edge);
            }
            subtreeEdges.set(edge.getIndex());
            edges.add(edge);
        }
        return subtreeEdges;
    }

    /**
     * Sets the (original and current) ends of {@code edge}, starting from {@code ends} and widening them with the
     * ends of child edges whose join type does not reorder with the new edge's join type, as decided by
     * {@link JoinType#isAssoc}, {@link JoinType#isLAssoc} and {@link JoinType#isRAssoc}.
     */
    private void initEdgeEnds(Pair<Long, Long> ends, Edge edge, BitSet leftEdges, BitSet rightEdges) {
        long left = ends.first;
        long right = ends.second;
        for (int i = leftEdges.nextSetBit(0); i >= 0; i = leftEdges.nextSetBit(i + 1)) {
            Edge childEdge = edges.get(i);
            if (!JoinType.isAssoc(childEdge.getJoinType(), edge.getJoinType())) {
                left = LongBitmap.or(left, childEdge.getLeft());
            }
            if (!JoinType.isLAssoc(childEdge.getJoinType(), edge.getJoinType())) {
                left = LongBitmap.or(left, childEdge.getRight());
            }
        }
        for (int i = rightEdges.nextSetBit(0); i >= 0; i = rightEdges.nextSetBit(i + 1)) {
            Edge childEdge = edges.get(i);
            if (!JoinType.isAssoc(childEdge.getJoinType(), edge.getJoinType())) {
                right = LongBitmap.or(right, childEdge.getRight());
            }
            if (!JoinType.isRAssoc(childEdge.getJoinType(), edge.getJoinType())) {
                right = LongBitmap.or(right, childEdge.getLeft());
            }
        }
        edge.setOriginalLeft(left);
        edge.setOriginalRight(right);
        edge.setLeft(left);
        edge.setRight(right);
    }

    /** Union-find lookup with path compression: returns the root of {@code i} and re-points {@code i} at it. */
    private int findRoot(List<Integer> parent, int i) {
        int root = parent.get(i);
        if (root != i) {
            root = findRoot(parent, root);
        }
        parent.set(i, root);
        return root;
    }

    /**
     * Checks whether all nodes in {@code bitmap} fall into one component, using only edges that touch no node of
     * {@code excludeBitmap}. A single-node bitmap is always connected.
     */
    private boolean isConnected(long bitmap, long excludeBitmap) {
        if (LongBitmap.getCardinality(bitmap) == 1) {
            return true;
        }
        List<Integer> parent = new ArrayList<>(nodes.size());
        for (int i = 0; i < nodes.size(); i++) {
            parent.add(i);
        }
        for (Edge edge : edges) {
            if (LongBitmap.isOverlap(edge.getLeft(), excludeBitmap)
                    || LongBitmap.isOverlap(edge.getRight(), excludeBitmap)) {
                continue;
            }
            int root = findRoot(parent, LongBitmap.nextSetBit(edge.getLeft(), 0));
            // Link each node's current root (not just the node itself); otherwise nodes already merged into
            // another component would be left behind.
            for (int nodeIndex : LongBitmap.getIterator(edge.getLeft())) {
                parent.set(findRoot(parent, nodeIndex), root);
            }
            for (int nodeIndex : LongBitmap.getIterator(edge.getRight())) {
                parent.set(findRoot(parent, nodeIndex), root);
            }
        }
        int root = findRoot(parent, LongBitmap.nextSetBit(bitmap, 0));
        for (int nodeIndex : LongBitmap.getIterator(bitmap)) {
            if (findRoot(parent, nodeIndex) != root) {
                return false;
            }
        }
        return true;
    }

    /**
     * Splits the nodes referenced by {@code condition} into two ends, each connected without the other.
     *
     * @return (left end, right end) of the first subset split returned by {@link LongBitmap#getSubsetIterator}
     *     where both sides are connected
     * @throws IllegalArgumentException if the condition references fewer than two nodes
     * @throws RuntimeException if no such split exists
     */
    private Pair<Long, Long> findEnds(Expression condition) {
        long referenced = calNodeMap(condition.getInputSlots());
        Preconditions.checkArgument(LongBitmap.getCardinality(referenced) > 1,
                "join condition %s must reference at least two nodes", condition);
        for (long left : LongBitmap.getSubsetIterator(referenced)) {
            long right = LongBitmap.newBitmapDiff(referenced, left);
            if (isConnected(left, right) && isConnected(right, left)) {
                return Pair.of(left, right);
            }
        }
        throw new RuntimeException("DPhyper meets unconnected subgraph");
    }

    /** Returns the union of the node bitmaps of {@code slots}; every slot must already be mapped. */
    private long calNodeMap(Set<Slot> slots) {
        Preconditions.checkArgument(!slots.isEmpty(), "expression has no input slots");
        long bitmap = LongBitmap.newBitmap();
        for (Slot slot : slots) {
            Preconditions.checkArgument(slotToNodeMap.containsKey(slot), "slot %s is not produced by any node", slot);
            bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot));
        }
        return bitmap;
    }

    /**
     * Replaces the current ends of edge {@code i} and re-attaches the edge to exactly the nodes it now references.
     * The original ends are left untouched.
     */
    public void modifyEdge(int i, long left, long right) {
        Edge edge = edges.get(i);
        // Diff the whole referenced set: a node that merely moves from one end to the other stays attached.
        long oldNodes = LongBitmap.or(edge.getLeft(), edge.getRight());
        long newNodes = LongBitmap.or(left, right);
        updateEdges(edge, oldNodes, newNodes);
        edge.setLeft(left);
        edge.setRight(right);
    }

    /** Detaches {@code edge} from nodes only in {@code oldNodes} and attaches it to nodes only in {@code newNodes}. */
    private void updateEdges(Edge edge, long oldNodes, long newNodes) {
        for (int nodeIndex : LongBitmap.getIterator(LongBitmap.newBitmapDiff(oldNodes, newNodes))) {
            nodes.get(nodeIndex).removeEdge(edge);
        }
        for (int nodeIndex : LongBitmap.getIterator(LongBitmap.newBitmapDiff(newNodes, oldNodes))) {
            nodes.get(nodeIndex).attachEdge(edge);
        }
    }

    /**
     * Renders the graph in Graphviz dot syntax. Nodes sharing a name get distinct ids by appending
     * {@code AggStateFunctionBuilder.COMBINATOR_LINKER} until the id is unused. Simple edges are drawn node-to-node;
     * other edges are drawn through an auxiliary point node {@code e<index>}.
     */
    public String toDottyHyperGraph() {
        StringBuilder builder = new StringBuilder();
        builder.append(String.format("digraph G {  # %d edges\n", edges.size()));
        // graphvizIds.get(k) is the unique dot id emitted for nodes.get(k).
        List<String> graphvizIds = new ArrayList<>(nodes.size());
        for (Node node : nodes) {
            String name = node.getName();
            String id = name;
            while (graphvizIds.contains(id)) {
                id += AggStateFunctionBuilder.COMBINATOR_LINKER;
            }
            builder.append(String.format("  %s [label=\"%s \n rowCount=%.2f\"];\n", id, name, node.getRowCount()));
            graphvizIds.add(id);
        }
        for (int i = 0; i < edges.size(); i++) {
            Edge edge = edges.get(i);
            String selectivity = String.format("%.2f", edge.getSelectivity());
            if (edge.isSimple()) {
                String arrowhead = edge.getJoin().getJoinType() == JoinType.INNER_JOIN ? ",arrowhead=none" : "";
                builder.append(String.format("%s -> %s [label=\"%s\"%s]\n",
                        graphvizIds.get(LongBitmap.lowestOneIndex(edge.getLeft())),
                        graphvizIds.get(LongBitmap.lowestOneIndex(edge.getRight())),
                        selectivity, arrowhead));
            } else {
                builder.append(String.format("e%d [shape=circle, width=.001, label=\"\"]\n", i));
                // The selectivity is printed on the right-side arrows when the left end is a single node,
                // and on the left-side arrows otherwise.
                boolean singleLeft = LongBitmap.getCardinality(edge.getLeft()) == 1;
                String leftLabel = singleLeft ? "" : selectivity;
                String rightLabel = singleLeft ? selectivity : "";
                for (int nodeIndex : LongBitmap.getIterator(edge.getLeft())) {
                    builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
                            graphvizIds.get(nodeIndex), i, leftLabel));
                }
                for (int nodeIndex : LongBitmap.getIterator(edge.getRight())) {
                    builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
                            graphvizIds.get(nodeIndex), i, rightLabel));
                }
            }
        }
        builder.append("}\n");
        return builder.toString();
    }
}
