package org.apache.doris.nereids.jobs.joinorder.hypergraph;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.AbstractReceiver;
import org.apache.doris.qe.ConnectContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Enumerates pairs of a connected subgraph (csg) and a connected complement (cmp) of a
 * {@link HyperGraph} and hands every pair that is joined by at least one edge to an
 * {@link AbstractReceiver}.
 *
 * <p>The method names ({@code emitCsg}, {@code enumerateCsgRec}, {@code enumerateCmpRec}) and the
 * {@code enableDpHypTrace} session flag indicate a DPhyp-style enumeration. Node sets are
 * {@link LongBitmap} bitmaps of node indices. Enumeration stops early and reports {@code false}
 * as soon as the receiver rejects a csg-cmp pair.
 */
public class SubgraphEnumerator {
    public static final Logger LOG = LogManager.getLogger(SubgraphEnumerator.class);
    AbstractReceiver receiver;
    HyperGraph hyperGraph;
    EdgeCalculator edgeCalculator;
    NeighborhoodCalculator neighborhoodCalculator;
    private final boolean enableTrace = isDpHypTraceEnabled();
    private final StringBuilder traceBuilder = new StringBuilder();

    /**
     * Reads the DP-hyp trace switch from the current session.
     *
     * <p>Without a bound session, tracing is simply disabled instead of failing construction
     * with a NullPointerException.
     */
    private static boolean isDpHypTraceEnabled() {
        // NOTE(review): assumes ConnectContext.get() may return null when no session is bound
        // to the current thread (e.g. background jobs or unit tests).
        ConnectContext context = ConnectContext.get();
        return context != null && context.getSessionVariable().enableDpHypTrace;
    }

    /**
     * Caches, per subgraph bitmap, which edges touch that subgraph.
     *
     * <ul>
     *   <li>"contain" edges: exactly one endpoint set (left or right) is a subset of the subgraph.
     *       After {@link #unionEdges}, the other endpoint is also required to be disjoint from
     *       it. These edges are split into simple and complex sets.</li>
     *   <li>"overlap" edges: exactly one endpoint set intersects the subgraph.</li>
     * </ul>
     *
     * <p>Edges are identified by their position in {@link #edges}.
     */
    public static class EdgeCalculator {
        final List<Edge> edges;
        // subgraph bitmap -> indices of simple edges contained by that subgraph
        HashMap<Long, BitSet> containSimpleEdges = new HashMap<>();
        // subgraph bitmap -> indices of complex (non-simple) edges contained by that subgraph
        HashMap<Long, BitSet> containComplexEdges = new HashMap<>();
        // subgraph bitmap -> indices of edges that overlap that subgraph on exactly one side
        HashMap<Long, BitSet> overlapEdges = new HashMap<>();

        EdgeCalculator(List<Edge> edges) {
            this.edges = edges;
        }

        /**
         * Classifies every edge against {@code subgraph} and records the result.
         *
         * <p>Any sets already recorded for the subgraph are merged in, not discarded.
         *
         * @param subgraph bitmap of the nodes forming the subgraph
         */
        public void initSubgraph(long subgraph) {
            BitSet simpleContains = new BitSet();
            BitSet complexContains = new BitSet();
            BitSet overlaps = new BitSet();
            for (Edge edge : edges) {
                if (isContainEdge(subgraph, edge)) {
                    if (edge.isSimple()) {
                        simpleContains.set(edge.getIndex());
                    } else {
                        complexContains.set(edge.getIndex());
                    }
                } else if (isOverlapEdge(subgraph, edge)) {
                    overlaps.set(edge.getIndex());
                }
            }
            // The three maps are always populated together, so presence in one implies the others.
            if (containSimpleEdges.containsKey(subgraph)) {
                complexContains.or(containComplexEdges.get(subgraph));
                simpleContains.or(containSimpleEdges.get(subgraph));
            }
            if (overlapEdges.containsKey(subgraph)) {
                overlaps.or(overlapEdges.get(subgraph));
            }
            overlapEdges.put(subgraph, overlaps);
            containSimpleEdges.put(subgraph, simpleContains);
            containComplexEdges.put(subgraph, complexContains);
        }

        /**
         * Derives and caches the edge sets of the union of two subgraphs from their own sets.
         *
         * <p>Either operand is initialised first if it has not been seen yet. Nothing is done
         * when the union is already cached.
         *
         * @param left  bitmap of the first subgraph
         * @param right bitmap of the second subgraph
         */
        public void unionEdges(long left, long right) {
            if (!containSimpleEdges.containsKey(left)) {
                initSubgraph(left);
            }
            if (!containSimpleEdges.containsKey(right)) {
                initSubgraph(right);
            }
            long union = LongBitmap.newBitmapUnion(left, right);
            if (containSimpleEdges.containsKey(union)) {
                return;
            }
            BitSet simpleContains = unionOf(containSimpleEdges, left, right);
            BitSet complexContains = unionOf(containComplexEdges, left, right);
            BitSet overlaps = unionOf(overlapEdges, left, right);
            // An edge that only overlapped one operand may be fully contained by the union:
            // move it from the overlap set to the matching contain set. Iterate over a snapshot
            // because bits are cleared inside the loop.
            for (int index : overlaps.stream().toArray()) {
                Edge edge = edges.get(index);
                if (isContainEdge(union, edge)) {
                    overlaps.clear(index);
                    if (edge.isSimple()) {
                        simpleContains.set(index);
                    } else {
                        complexContains.set(index);
                    }
                }
            }
            containSimpleEdges.put(union, removeInvalidEdges(union, simpleContains));
            containComplexEdges.put(union, removeInvalidEdges(union, complexContains));
            overlapEdges.put(union, overlaps);
        }

        /**
         * Returns the edges that connect {@code csg} with {@code cmp}.
         *
         * <p>These are the edges recorded as contained by both subgraphs. Simple edges come
         * first, then complex ones, each in ascending index order.
         *
         * @throws IllegalArgumentException if either subgraph has not been initialised
         */
        public List<Edge> connectCsgCmp(long csg, long cmp) {
            Preconditions.checkArgument(containSimpleEdges.containsKey(csg) && containSimpleEdges.containsKey(cmp));
            BitSet connecting = intersectionOf(containSimpleEdges, csg, cmp);
            connecting.or(intersectionOf(containComplexEdges, csg, cmp));
            return toEdges(connecting);
        }

        /**
         * Returns all (simple and complex) edges contained by {@code subgraph}.
         *
         * @throws IllegalStateException if the subgraph has not been initialised
         */
        public List<Edge> foundEdgesContain(long subgraph) {
            BitSet simpleContains = containSimpleEdges.get(subgraph);
            Preconditions.checkState(simpleContains != null);
            // Work on a copy: OR-ing into the cached BitSet would permanently add complex edges
            // to the subgraph's simple-edge cache and corrupt later neighborhood computations.
            BitSet allContains = (BitSet) simpleContains.clone();
            allContains.or(containComplexEdges.get(subgraph));
            return toEdges(allContains);
        }

        /** Returns the simple edges contained by {@code subgraph}, or an empty list if unknown. */
        public List<Edge> foundSimpleEdgesContain(long subgraph) {
            BitSet simpleContains = containSimpleEdges.get(subgraph);
            return simpleContains == null ? Collections.emptyList() : toEdges(simpleContains);
        }

        /** Returns the complex edges contained by {@code subgraph}, or an empty list if unknown. */
        public List<Edge> foundComplexEdgesContain(long subgraph) {
            BitSet complexContains = containComplexEdges.get(subgraph);
            return complexContains == null ? Collections.emptyList() : toEdges(complexContains);
        }

        /** True when exactly one of the edge's endpoint sets is a subset of {@code subgraph}. */
        private boolean isContainEdge(long subgraph, Edge edge) {
            return LongBitmap.isSubset(edge.getLeft(), subgraph) != LongBitmap.isSubset(edge.getRight(), subgraph);
        }

        /** True when exactly one of the edge's endpoint sets intersects {@code subgraph}. */
        private boolean isOverlapEdge(long subgraph, Edge edge) {
            return LongBitmap.isOverlap(edge.getLeft(), subgraph) != LongBitmap.isOverlap(edge.getRight(), subgraph);
        }

        /**
         * Clears, in place, every edge whose endpoint sets do not overlap {@code subgraph} on
         * exactly one side.
         *
         * @return the same (mutated) {@code edgeIndices} instance
         */
        private BitSet removeInvalidEdges(long subgraph, BitSet edgeIndices) {
            // Snapshot the indices because bits are cleared while iterating.
            for (int index : edgeIndices.stream().toArray()) {
                if (!isOverlapEdge(subgraph, edges.get(index))) {
                    edgeIndices.clear(index);
                }
            }
            return edgeIndices;
        }

        /** Fresh BitSet holding the union of the sets cached for {@code a} and {@code b}. */
        private static BitSet unionOf(HashMap<Long, BitSet> cache, long a, long b) {
            BitSet result = new BitSet();
            result.or(cache.get(a));
            result.or(cache.get(b));
            return result;
        }

        /** Fresh BitSet holding the intersection of the sets cached for {@code a} and {@code b}. */
        private static BitSet intersectionOf(HashMap<Long, BitSet> cache, long a, long b) {
            BitSet result = new BitSet();
            result.or(cache.get(a));
            result.and(cache.get(b));
            return result;
        }

        /** Maps set bits (ascending) to the corresponding edges in a new mutable list. */
        private List<Edge> toEdges(BitSet edgeIndices) {
            IntStream indices = edgeIndices.stream();
            return indices.mapToObj(edges::get).collect(Collectors.toCollection(ArrayList::new));
        }
    }

    /** Computes the neighborhood used to grow a subgraph during enumeration. */
    public static class NeighborhoodCalculator {
        NeighborhoodCalculator() {
        }

        /**
         * Returns the neighborhood bitmap of {@code subgraph}.
         *
         * <p>The result has two parts:
         * <ul>
         *   <li>every node referenced by a simple edge contained by the subgraph, minus the
         *       subgraph and {@code forbiddenNodes};</li>
         *   <li>for each contained complex edge whose far endpoint set avoids the subgraph, the
         *       forbidden nodes and the simple neighbors, the lowest node of that endpoint set.</li>
         * </ul>
         *
         * @param subgraph       bitmap of the current subgraph
         * @param forbiddenNodes bitmap of nodes that must not be added
         * @param edgeCalculator edge cache that already knows {@code subgraph}
         */
        public long calcNeighborhood(long subgraph, long forbiddenNodes, EdgeCalculator edgeCalculator) {
            long simpleReach = LongBitmap.newBitmap();
            for (Edge edge : edgeCalculator.foundSimpleEdgesContain(subgraph)) {
                simpleReach = LongBitmap.or(simpleReach, edge.getReferenceNodes());
            }
            long excluded = LongBitmap.or(forbiddenNodes, subgraph);
            long neighborhood = LongBitmap.andNot(simpleReach, excluded);
            // Nodes a complex edge's far side must avoid; intentionally not updated inside the loop.
            long covered = LongBitmap.or(excluded, neighborhood);
            for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) {
                long left = edge.getLeft();
                long right = edge.getRight();
                if (LongBitmap.isSubset(left, subgraph) && !LongBitmap.isOverlap(right, covered)) {
                    neighborhood = LongBitmap.set(neighborhood, LongBitmap.lowestOneIndex(right));
                } else if (LongBitmap.isSubset(right, subgraph) && !LongBitmap.isOverlap(left, covered)) {
                    neighborhood = LongBitmap.set(neighborhood, LongBitmap.lowestOneIndex(left));
                }
            }
            return neighborhood;
        }
    }

    public SubgraphEnumerator(AbstractReceiver abstractReceiver, HyperGraph hyperGraph) {
        this.receiver = abstractReceiver;
        this.hyperGraph = hyperGraph;
    }

    /**
     * Runs the full enumeration over the hypergraph.
     *
     * <p>Resets the receiver and registers every node's group. It then starts from each single
     * node (from index size-2 down to 0), forbidding all lower-indexed nodes.
     *
     * @return {@code true} if enumeration completed, {@code false} if the receiver rejected a pair
     */
    public boolean enumerate() {
        // Start each run with an empty trace so repeated runs do not re-log earlier output.
        traceBuilder.setLength(0);
        if (enableTrace) {
            traceBuilder.append("Query Graph Graphviz: ").append(hyperGraph.toDottyHyperGraph()).append("\n");
        }
        receiver.reset();
        List<Node> nodes = hyperGraph.getNodes();
        for (Node node : nodes) {
            receiver.addGroup(node.getNodeMap(), node.getGroup());
        }
        int size = nodes.size();
        edgeCalculator = new EdgeCalculator(hyperGraph.getEdges());
        for (Node node : nodes) {
            edgeCalculator.initSubgraph(node.getNodeMap());
        }
        neighborhoodCalculator = new NeighborhoodCalculator();
        // The last node is skipped: with every other node forbidden it cannot form a csg-cmp pair.
        long forbiddenNodes = LongBitmap.newBitmapBetween(0, size - 1);
        for (int i = size - 2; i >= 0; i--) {
            if (enableTrace) {
                traceBuilder.append("Starting main iteration at node[").append(i).append("]\n");
            }
            long csg = LongBitmap.newBitmap(i);
            forbiddenNodes = LongBitmap.unset(forbiddenNodes, i);
            if (!emitCsg(csg) || !enumerateCsgRec(csg, LongBitmap.clone(forbiddenNodes))) {
                return false;
            }
        }
        if (enableTrace) {
            LOG.info(traceBuilder.toString());
        }
        return true;
    }

    /**
     * Grows {@code csg} by every non-empty subset of its neighborhood.
     *
     * <p>Each grown subgraph the receiver already contains is emitted. The method then recurses
     * with the neighborhood added to the forbidden set.
     */
    private boolean enumerateCsgRec(long csg, long forbiddenNodes) {
        long neighborhood = neighborhoodCalculator.calcNeighborhood(csg, forbiddenNodes, edgeCalculator);
        LongBitmapSubsetIterator subsets = LongBitmap.getSubsetIterator(neighborhood);
        if (enableTrace) {
            traceBuilder.append("Expanding connected subgraph, subgraph=[").append(LongBitmap.toString(csg)).append("], neighborhood=[").append(LongBitmap.toString(neighborhood)).append("], forbidden=[").append(LongBitmap.toString(forbiddenNodes)).append("]\n");
        }
        Iterator<Long> subsetIt = subsets.iterator();
        while (subsetIt.hasNext()) {
            long subset = subsetIt.next();
            long grown = LongBitmap.newBitmapUnion(csg, subset);
            // Must run before any lookup on `grown`: it populates the edge cache for the union.
            edgeCalculator.unionEdges(csg, subset);
            if (receiver.contain(grown) && !emitCsg(grown)) {
                return false;
            }
        }
        long nextForbidden = LongBitmap.or(forbiddenNodes, neighborhood);
        subsets.reset();
        Iterator<Long> recurseIt = subsets.iterator();
        while (recurseIt.hasNext()) {
            if (!enumerateCsgRec(LongBitmap.newBitmapUnion(csg, recurseIt.next()), LongBitmap.clone(nextForbidden))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Grows the complement {@code cmp} of {@code csg} by every non-empty subset of its
     * neighborhood.
     *
     * <p>Each grown complement the receiver contains, and that is connected to {@code csg} by at
     * least one edge, is emitted as a pair. The method then recurses.
     */
    private boolean enumerateCmpRec(long csg, long cmp, long forbiddenNodes) {
        long neighborhood = neighborhoodCalculator.calcNeighborhood(cmp, forbiddenNodes, edgeCalculator);
        LongBitmapSubsetIterator subsets = new LongBitmapSubsetIterator(neighborhood);
        if (enableTrace) {
            traceBuilder.append("Expanding complement subgraph, subgraph=[").append(LongBitmap.toString(cmp)).append("], neighborhood=[").append(LongBitmap.toString(neighborhood)).append("], forbidden=[").append(LongBitmap.toString(forbiddenNodes)).append("]\n");
        }
        Iterator<Long> subsetIt = subsets.iterator();
        while (subsetIt.hasNext()) {
            long subset = subsetIt.next();
            long grown = LongBitmap.newBitmapUnion(cmp, subset);
            edgeCalculator.unionEdges(cmp, subset);
            if (receiver.contain(grown)) {
                List<Edge> connecting = edgeCalculator.connectCsgCmp(csg, grown);
                if (!connecting.isEmpty() && !receiver.emitCsgCmp(csg, grown, connecting)) {
                    return false;
                }
            }
        }
        long nextForbidden = LongBitmap.or(forbiddenNodes, neighborhood);
        subsets.reset();
        Iterator<Long> recurseIt = subsets.iterator();
        while (recurseIt.hasNext()) {
            if (!enumerateCmpRec(csg, LongBitmap.newBitmapUnion(cmp, recurseIt.next()), LongBitmap.clone(nextForbidden))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Emits every pair of {@code csg} with a single neighbor node, then grows each such
     * complement.
     *
     * <p>Neighbors are visited in reverse (descending) order. Nodes with an index lower than
     * csg's lowest node are forbidden.
     */
    private boolean emitCsg(long csg) {
        long forbiddenNodes = LongBitmap.or(LongBitmap.newBitmapBetween(0, LongBitmap.nextSetBit(csg, 0)), csg);
        long neighborhood = neighborhoodCalculator.calcNeighborhood(csg, LongBitmap.clone(forbiddenNodes), edgeCalculator);
        if (enableTrace && LongBitmap.getCardinality(csg) == 1) {
            traceBuilder.append("Emitting connected subgraph, subgraph=[").append(LongBitmap.toString(csg)).append("], neighborhood=[").append(LongBitmap.toString(neighborhood)).append("], forbidden=[").append(LongBitmap.toString(forbiddenNodes)).append("]\n");
        }
        Iterator<Integer> neighborIt = LongBitmap.getReverseIterator(neighborhood).iterator();
        while (neighborIt.hasNext()) {
            int nodeIndex = neighborIt.next();
            long cmp = LongBitmap.newBitmap(nodeIndex);
            List<Edge> connecting = edgeCalculator.connectCsgCmp(csg, cmp);
            if (!connecting.isEmpty() && !receiver.emitCsgCmp(csg, cmp, connecting)) {
                return false;
            }
            // Neighbors up to and including this one are forbidden while growing the complement.
            long cmpForbidden = LongBitmap.or(LongBitmap.and(LongBitmap.newBitmapBetween(0, nodeIndex + 1), neighborhood), forbiddenNodes);
            if (!enumerateCmpRec(csg, cmp, cmpForbidden)) {
                return false;
            }
        }
        return true;
    }
}
