/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.NeighborArray;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.FixedBitSet;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

/**
 * The neighbor list of a single graph node.
 *
 * <p>The neighbors live in a {@link NeighborArray} held by an {@link AtomicReference}. Every
 * mutating method in this class computes a replacement array (or returns the current one
 * unchanged) inside {@link AtomicReference#getAndUpdate}; arrays that are already published are
 * never modified by this class. As a result, an iterator obtained from {@link #iterator()} keeps
 * walking the array that was current when it was created.
 *
 * <p>Diversity pruning ({@link #cleanup()}, {@link #insertDiverse}, and overflow handling in
 * {@link #insert}) reduces the list to at most {@code maxConnections} entries. It prefers
 * candidates that are not "covered" by an already-selected, higher-ranked neighbor.
 */
public class ConcurrentNeighborSet {
    /** Id of the node that owns this set; it must never be added as its own neighbor. */
    private final int nodeId;
    /** Current neighbor list; the update methods replace it rather than mutate it. */
    private final AtomicReference<NeighborArray> neighborsRef;
    /** Largest diversity multiplier tried by {@link #selectDiverse} (starting from 1.0 in 0.2 steps). */
    private final float alpha;
    /** Supplies the node-to-node similarity used by the diversity check. */
    private final NeighborSimilarity similarity;
    /** Maximum number of neighbors kept by diversity pruning. */
    private final int maxConnections;
    /**
     * Fraction of {@code maxConnections} filled by the alpha == 1.0 pass of the most recent
     * {@link #selectDiverse} call, or NaN if none has run on this instance.
     *
     * <p>This is a plain (non-volatile) field written from inside update functions. It is not
     * synchronized with {@link #neighborsRef}.
     */
    private float shortEdges = Float.NaN;

    /** Creates a neighbor set with {@code alpha = 1.0}. */
    public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) {
        this(nodeId, maxConnections, similarity, 1.0f);
    }

    /** Creates a neighbor set backed by a new {@code NeighborArray(maxConnections)}. */
    public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity, float alpha) {
        this(nodeId, maxConnections, similarity, alpha, new NeighborArray(maxConnections));
    }

    /**
     * Creates a neighbor set whose initial contents are {@code neighbors}.
     *
     * <p>The array is stored as-is (not copied), so the caller should not modify it afterwards.
     */
    ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity, float alpha, NeighborArray neighbors) {
        this.nodeId = nodeId;
        this.maxConnections = maxConnections;
        this.similarity = similarity;
        this.alpha = alpha;
        this.neighborsRef = new AtomicReference<NeighborArray>(neighbors);
    }

    /**
     * Copy constructor.
     *
     * <p>The copy starts out sharing {@code old}'s current array. This is safe because the update
     * methods replace the array instead of modifying it. {@code shortEdges} is not copied and
     * starts as NaN.
     */
    private ConcurrentNeighborSet(ConcurrentNeighborSet old) {
        this.nodeId = old.nodeId;
        this.maxConnections = old.maxConnections;
        this.similarity = old.similarity;
        this.alpha = old.alpha;
        this.neighborsRef = new AtomicReference<NeighborArray>(old.neighborsRef.get());
    }

    /**
     * Returns the short-edge fraction recorded by the most recent diversity selection.
     *
     * @return the fraction, or NaN if no selection has run on this instance
     */
    public float getShortEdges() {
        return shortEdges;
    }

    /** Returns an iterator over the node ids in the array that is current at the time of the call. */
    public NodesIterator iterator() {
        return new NeighborIterator(neighborsRef.get());
    }

    /**
     * Inserts this node into the neighbor set of each of its current neighbors.
     *
     * <p>Each back-link reuses the score stored on the forward edge.
     *
     * @param neighborhoodOf maps a node id to that node's neighbor set
     * @param overflow passed through to {@link #insert}; controls when the target set is pruned
     */
    public void backlink(Function<Integer, ConcurrentNeighborSet> neighborhoodOf, float overflow) {
        NeighborArray snapshot = neighborsRef.get();
        for (int i = 0; i < snapshot.size(); ++i) {
            int neighbor = snapshot.node[i];
            float neighborScore = snapshot.score[i];
            neighborhoodOf.apply(neighbor).insert(nodeId, neighborScore, overflow);
        }
    }

    /** Prunes the list down to a diverse subset if it holds more than {@code maxConnections} entries. */
    public void cleanup() {
        neighborsRef.getAndUpdate(this::removeAllNonDiverse);
    }

    /**
     * Removes every neighbor whose id is set in {@code deletedNodes}.
     *
     * @param deletedNodes bit set of deleted node ids
     * @return {@code true} if the committed update removed at least one neighbor
     */
    public boolean removeDeletedNeighbors(Bits deletedNodes) {
        AtomicBoolean found = new AtomicBoolean();
        neighborsRef.getAndUpdate(current -> {
            FixedBitSet toRetain = new FixedBitSet(current.size);
            boolean anyDeleted = false;
            for (int i = 0; i < current.size; ++i) {
                if (deletedNodes.get(current.node[i])) {
                    anyDeleted = true;
                } else {
                    toRetain.set(i);
                }
            }
            // getAndUpdate may re-run this function after a failed CAS. Record only this attempt's
            // outcome so that a stale "true" from an earlier attempt cannot leak into the result.
            found.set(anyDeleted);
            if (!anyDeleted) {
                // Nothing to remove: keep the published array instead of copying it.
                return current;
            }
            NeighborArray next = current.copy();
            next.retain(toRetain);
            return next;
        });
        return found.get();
    }

    /** Returns the number of neighbors in the current array. */
    public int size() {
        return neighborsRef.get().size();
    }

    /** Returns the length of the current array's backing {@code node} storage. */
    public int arrayLength() {
        return neighborsRef.get().node.length;
    }

    /**
     * Merges two batches of candidates into the current neighbors and keeps only a diverse subset.
     *
     * <p>The subset is chosen by {@link #selectDiverse} and holds at most {@code maxConnections}
     * entries. Neither argument is modified, because merging always allocates a new array.
     *
     * @param natural candidate neighbors to merge
     * @param concurrent additional candidate neighbors to merge
     */
    public void insertDiverse(NeighborArray natural, NeighborArray concurrent) {
        if (natural.size() == 0 && concurrent.size() == 0) {
            return;
        }
        // The candidate set does not depend on the current neighbors, so build it once here
        // instead of on every retry of the update function.
        final NeighborArray toMerge = concurrent.size == 0
                ? natural
                : (natural.size == 0 ? concurrent : mergeNeighbors(natural, concurrent));
        neighborsRef.getAndUpdate(current -> {
            NeighborArray merged = mergeNeighbors(current, toMerge);
            BitSet selected = selectDiverse(merged);
            merged.retain(selected);
            return merged;
        });
    }

    /**
     * Merges {@code connections} into the current neighbors.
     *
     * <p>No diversity pruning or size cap is applied.
     */
    void padWithRandom(NeighborArray connections) {
        neighborsRef.getAndUpdate(current -> mergeNeighbors(current, connections));
    }

    /**
     * Inserts one neighbor in score order without a diversity check.
     *
     * @param node neighbor id to insert
     * @param score similarity score of the new neighbor
     * @param limitConnections if {@code true}, the copied list is truncated to
     *     {@code maxConnections - 1} entries before the insert
     */
    void insertNotDiverse(int node, float score, boolean limitConnections) {
        neighborsRef.getAndUpdate(current -> {
            NeighborArray next = current.copy();
            if (limitConnections) {
                next.size = Math.min(next.size, maxConnections - 1);
            }
            next.insertSorted(node, score);
            return next;
        });
    }

    /**
     * Builds a new array from the entries of {@code merged} whose index is set in {@code selected}.
     *
     * <p>Relative order is preserved.
     */
    private NeighborArray copyDiverse(NeighborArray merged, BitSet selected) {
        NeighborArray next = new NeighborArray(maxConnections);
        for (int i = 0; i < merged.size(); ++i) {
            if (!selected.get(i)) {
                continue;
            }
            int node = merged.node()[i];
            assert (node != this.nodeId) : "can't add self as neighbor at node " + this.nodeId;
            next.addInOrder(node, merged.score()[i]);
        }
        assert (next.size <= this.maxConnections);
        return next;
    }

    /**
     * Chooses up to {@code maxConnections} diverse entries of {@code neighbors}.
     *
     * <p>Candidates are scanned in array order. Each scan uses a diversity multiplier {@code a}
     * that starts at 1.0 and grows by 0.2 per pass up to {@code alpha}. Candidates that pass the
     * strictest threshold are therefore picked first.
     *
     * <p>As a side effect, this method updates {@link #shortEdges} after the a == 1.0 pass.
     *
     * @return a bit set over indices of {@code neighbors} marking the selected entries
     */
    private BitSet selectDiverse(NeighborArray neighbors) {
        BitSet selected = new FixedBitSet(neighbors.size());
        int nSelected = 0;
        for (float a = 1.0f; a <= alpha + 1.0E-6 && nSelected < maxConnections; a += 0.2f) {
            for (int i = 0; i < neighbors.size() && nSelected < maxConnections; ++i) {
                if (selected.get(i)) {
                    continue;
                }
                if (isDiverse(neighbors.node()[i], neighbors.score()[i], neighbors, selected, a)) {
                    selected.set(i);
                    ++nSelected;
                }
            }
            if (a == 1.0f) {
                // Not synchronized; see the field comment.
                shortEdges = (float) nSelected / (float) maxConnections;
            }
        }
        return selected;
    }

    /** Returns the current neighbor array without copying it. */
    NeighborArray getCurrent() {
        return neighborsRef.get();
    }

    /**
     * Merges two score-ordered neighbor arrays into a newly allocated array.
     *
     * <p>When scores differ, the entry with the higher score is emitted first. Both inputs are
     * therefore expected to be sorted by descending score.
     *
     * <p>A node id is skipped only if it was already emitted within the current run of exactly
     * equal scores. The same node with different scores is kept twice. Once either input is
     * exhausted, leftover entries that tie the last emitted score are still de-duplicated; the
     * remainder is bulk-copied.
     *
     * <p>Neither input is modified.
     */
    static NeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2) {
        NeighborArray merged = new NeighborArray(a1.size() + a2.size());
        int i = 0;
        int j = 0;
        // Node ids already emitted with score == lastAddedScore.
        HashSet<Integer> nodesWithLastScore = new HashSet<Integer>();
        float lastAddedScore = Float.NaN;
        while (i < a1.size() && j < a2.size()) {
            if (a1.score()[i] < a2.score[j]) {
                // a2's entry ranks higher.
                if (a2.score[j] != lastAddedScore) {
                    nodesWithLastScore.clear();
                    lastAddedScore = a2.score[j];
                }
                if (nodesWithLastScore.add(a2.node[j])) {
                    merged.addInOrder(a2.node[j], a2.score[j]);
                }
                ++j;
                continue;
            }
            if (a1.score()[i] > a2.score[j]) {
                // a1's entry ranks higher.
                if (a1.score()[i] != lastAddedScore) {
                    nodesWithLastScore.clear();
                    lastAddedScore = a1.score()[i];
                }
                if (nodesWithLastScore.add(a1.node()[i])) {
                    merged.addInOrder(a1.node()[i], a1.score()[i]);
                }
                ++i;
                continue;
            }
            // Equal scores: emit both entries, skipping whichever node is already present.
            if (a1.score()[i] != lastAddedScore) {
                nodesWithLastScore.clear();
                lastAddedScore = a1.score()[i];
            }
            if (nodesWithLastScore.add(a1.node()[i])) {
                merged.addInOrder(a1.node()[i], a1.score()[i]);
            }
            if (nodesWithLastScore.add(a2.node()[j])) {
                merged.addInOrder(a2.node[j], a2.score[j]);
            }
            ++i;
            ++j;
        }
        if (i < a1.size()) {
            // Leftovers tying the last emitted score still need the duplicate check.
            while (i < a1.size && a1.score()[i] == lastAddedScore) {
                if (!nodesWithLastScore.contains(a1.node()[i])) {
                    merged.addInOrder(a1.node()[i], a1.score()[i]);
                }
                ++i;
            }
            System.arraycopy(a1.node, i, merged.node, merged.size, a1.size - i);
            System.arraycopy(a1.score, i, merged.score, merged.size, a1.size - i);
            merged.size += a1.size - i;
        }
        if (j < a2.size()) {
            while (j < a2.size && a2.score[j] == lastAddedScore) {
                if (!nodesWithLastScore.contains(a2.node[j])) {
                    merged.addInOrder(a2.node[j], a2.score[j]);
                }
                ++j;
            }
            System.arraycopy(a2.node, j, merged.node, merged.size, a2.size - j);
            System.arraycopy(a2.score, j, merged.score, merged.size, a2.size - j);
            merged.size += a2.size - j;
        }
        return merged;
    }

    /**
     * Inserts one neighbor in score order.
     *
     * <p>The list is pruned with {@link #removeAllNonDiverse} only once its size exceeds
     * {@code overflow * maxConnections}.
     *
     * @param neighborId neighbor to add; must differ from this set's own node id
     * @param score similarity score of the new neighbor
     * @param overflow multiple of {@code maxConnections} the list may reach before it is pruned
     */
    public void insert(int neighborId, float score, float overflow) {
        assert (neighborId != this.nodeId) : "can't add self as neighbor at node " + this.nodeId;
        // Independent of the current contents, so compute it once rather than on every CAS retry.
        float hardMax = overflow * (float) maxConnections;
        neighborsRef.getAndUpdate(current -> {
            NeighborArray next = current.copy();
            next.insertSorted(neighborId, score);
            return (float) next.size > hardMax ? removeAllNonDiverse(next) : next;
        });
    }

    /**
     * Decides whether {@code node} is diverse with respect to the already-selected entries.
     *
     * <p>Walks the selected indices of {@code others} in ascending order, stopping at the first
     * selected entry whose id equals {@code node}. The candidate is rejected if any earlier
     * selected neighbor is more similar to it than {@code score * alpha}.
     */
    private boolean isDiverse(int node, float score, NeighborArray others, BitSet selected, float alpha) {
        if (others.size() == 0) {
            return true;
        }
        NeighborSimilarity.ScoreFunction scoreProvider = similarity.scoreProvider(node);
        // Integer.MAX_VALUE is the "no more set bits" sentinel returned by nextSetBit.
        for (int i = selected.nextSetBit(0); i != Integer.MAX_VALUE; i = selected.nextSetBit(i + 1)) {
            int otherNode = others.node()[i];
            if (node == otherNode) {
                break;
            }
            if (scoreProvider.similarityTo(otherNode) > score * alpha) {
                return false;
            }
            // Calling nextSetBit past the end of the bit set is not allowed, so stop here.
            if (i + 1 >= selected.length()) {
                break;
            }
        }
        return true;
    }

    /**
     * Returns {@code neighbors} unchanged if it holds at most {@code maxConnections} entries.
     *
     * <p>Otherwise returns a new array holding only its diverse subset.
     */
    private NeighborArray removeAllNonDiverse(NeighborArray neighbors) {
        if (neighbors.size <= maxConnections) {
            return neighbors;
        }
        return copyDiverse(neighbors, selectDiverse(neighbors));
    }

    /** Returns a new set that starts out with this set's current neighbors. */
    public ConcurrentNeighborSet copy() {
        return new ConcurrentNeighborSet(this);
    }

    /** Returns {@code true} if {@code i} is in the current neighbor array (linear scan). */
    boolean contains(int i) {
        NodesIterator it = iterator();
        while (it.hasNext()) {
            if (it.nextInt() == i) {
                return true;
            }
        }
        return false;
    }

    /** Iterates the node ids of one fixed {@link NeighborArray}, in array order. */
    private static class NeighborIterator
    extends NodesIterator {
        private final NeighborArray neighbors;
        private int i;

        private NeighborIterator(NeighborArray neighbors) {
            super(neighbors.size());
            this.neighbors = neighbors;
            this.i = 0;
        }

        @Override
        public boolean hasNext() {
            return this.i < this.neighbors.size();
        }

        @Override
        public int nextInt() {
            return this.neighbors.node[this.i++];
        }
    }
}

