/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.Experimental;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborMap;
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.MutableGraphIndex;
import io.github.jbellis.jvector.graph.NodeArray;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.DenseIntMap;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.util.SparseIntMap;
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.PrimitiveIterator;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.StampedLock;
import java.util.stream.IntStream;
import org.agrona.collections.IntArrayList;

public class OnHeapGraphIndex
implements MutableGraphIndex {
    public static final int MAGIC = 1978417170;
    private final AtomicReference<ImmutableGraphIndex.NodeAtLevel> entryPoint;
    final List<ConcurrentNeighborMap> layers = new ArrayList<ConcurrentNeighborMap>();
    private final CompletionTracker completions;
    private final ThreadSafeGrowableBitSet deletedNodes = new ThreadSafeGrowableBitSet(0);
    private final AtomicInteger maxNodeId = new AtomicInteger(-1);
    final List<Integer> maxDegrees;
    private final int dimension;
    private final double overflowRatio;
    private volatile boolean allMutationsCompleted = false;
    private final boolean isHierarchical;

    OnHeapGraphIndex(List<Integer> maxDegrees, int dimension, double overflowRatio, DiversityProvider diversityProvider, boolean isHierarchical) {
        this.overflowRatio = overflowRatio;
        this.maxDegrees = new IntArrayList();
        this.dimension = dimension;
        this.setDegrees(maxDegrees);
        this.entryPoint = new AtomicReference();
        this.completions = new CompletionTracker(1024);
        this.layers.add(new ConcurrentNeighborMap(new DenseIntMap<ConcurrentNeighborMap.Neighbors>(1024), diversityProvider, this.getDegree(0), (int)((double)this.getDegree(0) * overflowRatio)));
        this.isHierarchical = isHierarchical;
    }

    ConcurrentNeighborMap.Neighbors getNeighbors(int level, int node) {
        if (level >= this.layers.size()) {
            return null;
        }
        return this.layers.get(level).get(node);
    }

    @Override
    public NodesIterator getNeighborsIterator(ImmutableGraphIndex.NodeAtLevel nodeAtLevel) {
        return this.getNeighborsIterator(nodeAtLevel.level, nodeAtLevel.node);
    }

    @Override
    public NodesIterator getNeighborsIterator(int level, int node) {
        if (level >= this.layers.size()) {
            return NodesIterator.EMPTY_NODE_ITERATOR;
        }
        ConcurrentNeighborMap.Neighbors neighs = this.layers.get(level).get(node);
        if (neighs == null) {
            return NodesIterator.EMPTY_NODE_ITERATOR;
        }
        return neighs.iterator();
    }

    @Override
    public boolean isHierarchical() {
        return this.isHierarchical;
    }

    @Override
    public int getMaxLevelForNode(int node) {
        int maxLayer = -1;
        int lvl = 0;
        while (lvl < this.layers.size() && this.getNeighbors(lvl, node) != null) {
            maxLayer = lvl++;
        }
        return maxLayer;
    }

    @Override
    public int size(int level) {
        return this.layers.get(level).size();
    }

    @Override
    public void addNode(ImmutableGraphIndex.NodeAtLevel nodeLevel) {
        this.addNode(nodeLevel.level, nodeLevel.node);
    }

    @Override
    public void addNode(int level, int node) {
        this.ensureLayersExist(level);
        for (int i = 0; i <= level; ++i) {
            this.layers.get(i).addNode(node);
        }
        this.maxNodeId.accumulateAndGet(node, Math::max);
    }

    @Override
    public boolean contains(ImmutableGraphIndex.NodeAtLevel nodeLevel) {
        return this.contains(nodeLevel.level, nodeLevel.node);
    }

    @Override
    public boolean contains(int level, int node) {
        return this.layers.get(level).contains(node);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void ensureLayersExist(int level) {
        for (int i = this.layers.size(); i <= level; ++i) {
            List<ConcurrentNeighborMap> list = this.layers;
            synchronized (list) {
                if (i == this.layers.size()) {
                    ConcurrentNeighborMap denseMap = this.layers.get(0);
                    ConcurrentNeighborMap map = new ConcurrentNeighborMap(new SparseIntMap<ConcurrentNeighborMap.Neighbors>(), denseMap.diversityProvider, this.getDegree(level), (int)((double)this.getDegree(level) * this.overflowRatio));
                    this.layers.add(map);
                }
                continue;
            }
        }
    }

    @Override
    public void connectNode(ImmutableGraphIndex.NodeAtLevel nodeLevel, NodeArray nodes) {
        this.connectNode(nodeLevel.level, nodeLevel.node, nodes);
    }

    @Override
    public void connectNode(int level, int node, NodeArray nodes) {
        assert (nodes != null);
        this.ensureLayersExist(level);
        this.layers.get(level).addNode(node, nodes);
        this.maxNodeId.accumulateAndGet(node, Math::max);
    }

    @Override
    public void markDeleted(int node) {
        this.deletedNodes.set(node);
    }

    @Override
    public void markComplete(ImmutableGraphIndex.NodeAtLevel nodeLevel) {
        this.entryPoint.accumulateAndGet(nodeLevel, (oldEntry, newEntry) -> {
            if (oldEntry == null || newEntry.level > oldEntry.level) {
                return newEntry;
            }
            return oldEntry;
        });
        this.completions.markComplete(nodeLevel.node);
    }

    @Override
    public void updateEntryNode(ImmutableGraphIndex.NodeAtLevel newEntry) {
        this.entryPoint.set(newEntry);
    }

    @Override
    public ImmutableGraphIndex.NodeAtLevel entryNode() {
        return this.entryPoint.get();
    }

    @Override
    public NodesIterator getNodes(int level) {
        return NodesIterator.fromPrimitiveIterator(this.nodeStream(level).iterator(), this.layers.get(level).size());
    }

    @Override
    public IntStream nodeStream(int level) {
        ConcurrentNeighborMap layer = this.layers.get(level);
        return level == 0 ? IntStream.range(0, this.getIdUpperBound()).filter(i -> layer.get(i) != null) : ((SparseIntMap)layer.neighbors).keysStream();
    }

    @Override
    public long ramBytesUsed() {
        long graphBytesUsed = IntStream.range(0, this.layers.size()).mapToLong(this::ramBytesUsedOneLayer).sum();
        return graphBytesUsed + this.completions.ramBytesUsed();
    }

    private long ramBytesUsedOneLayer(int level) {
        int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
        int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
        long neighborSize = this.ramBytesUsedOneNode(level) * (long)this.layers.get(level).size();
        return (long)OH_BYTES + (long)REF_BYTES * 2L + (long)AH_BYTES + neighborSize;
    }

    @Override
    public long ramBytesUsedOneNode(int level) {
        int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        return (long)REF_BYTES + ConcurrentNeighborMap.Neighbors.ramBytesUsed(this.layers.get(level).nodeArrayLength());
    }

    @Override
    public void enforceDegree(int node) {
        for (int level = 0; level <= this.getMaxLevel(); ++level) {
            this.layers.get(level).enforceDegree(node);
        }
    }

    @Override
    public void addEdges(int level, int node, NodeArray candidates, float overflowRatio) {
        ConcurrentNeighborMap.Neighbors newNeighbors = this.layers.get(level).insertDiverse(node, candidates);
        this.layers.get(level).backlink(newNeighbors, node, overflowRatio);
    }

    @Override
    public void replaceDeletedNeighbors(int level, int node, BitSet toDelete, NodeArray candidates) {
        this.layers.get(level).replaceDeletedNeighbors(node, toDelete, candidates);
    }

    public String toString() {
        return String.format("OnHeapGraphIndex(size=%d, entryPoint=%s)", this.size(0), this.entryPoint.get());
    }

    @Override
    public void close() {
    }

    @Override
    public ImmutableGraphIndex.View getView() {
        if (this.allMutationsCompleted) {
            return new FrozenView();
        }
        return new ConcurrentGraphIndexView();
    }

    @Override
    public ThreadSafeGrowableBitSet getDeletedNodes() {
        return this.deletedNodes;
    }

    @Override
    public int removeNode(int node) {
        int found = 0;
        for (ConcurrentNeighborMap layer : this.layers) {
            if (layer.remove(node) == null) continue;
            ++found;
        }
        this.deletedNodes.clear(node);
        return found;
    }

    @Override
    public int getIdUpperBound() {
        return this.maxNodeId.get() + 1;
    }

    @Override
    public boolean containsNode(int nodeId) {
        return this.layers.get(0).contains(nodeId);
    }

    @Override
    public double getAverageDegree(int level) {
        return this.nodeStream(level).mapToDouble(i -> this.getNeighbors(level, i).size()).average().orElse(Double.NaN);
    }

    @Override
    public int getMaxLevel() {
        for (int lvl = 0; lvl < this.layers.size(); ++lvl) {
            if (this.layers.get(lvl).size() != 0) continue;
            return lvl - 1;
        }
        return this.layers.size() - 1;
    }

    @Override
    public int getDegree(int level) {
        if (level >= this.maxDegrees.size()) {
            return this.maxDegrees.get(this.maxDegrees.size() - 1);
        }
        return this.maxDegrees.get(level);
    }

    @Override
    public int maxDegree() {
        return this.maxDegrees.stream().mapToInt(i -> i).max().orElseThrow();
    }

    @Override
    public List<Integer> maxDegrees() {
        return this.maxDegrees;
    }

    @Override
    public void setDegrees(List<Integer> layerDegrees) {
        this.maxDegrees.clear();
        this.maxDegrees.addAll(layerDegrees);
    }

    @Override
    public int getDimension() {
        return this.dimension;
    }

    @Override
    public void setAllMutationsCompleted() {
        this.allMutationsCompleted = true;
    }

    @Override
    public boolean allMutationsCompleted() {
        return this.allMutationsCompleted;
    }

    @Deprecated
    @Experimental
    public void save(DataOutput out) throws IOException {
        if (!this.allMutationsCompleted()) {
            throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first");
        }
        out.writeInt(1978417170);
        out.writeInt(4);
        out.writeInt(this.layers.size());
        for (int level = 0; level < this.layers.size(); ++level) {
            out.writeInt(this.getDegree(level));
        }
        ImmutableGraphIndex.NodeAtLevel entryNode = this.entryPoint.get();
        assert (entryNode.level == this.getMaxLevel());
        out.writeInt(entryNode.node);
        for (int level = 0; level < this.layers.size(); ++level) {
            out.writeInt(this.size(level));
            PrimitiveIterator.OfInt it = this.nodeStream(level).iterator();
            while (it.hasNext()) {
                int nodeId = it.nextInt();
                ConcurrentNeighborMap.Neighbors neighbors = this.layers.get(level).get(nodeId);
                out.writeInt(nodeId);
                out.writeInt(neighbors.size());
                for (int n = 0; n < neighbors.size(); ++n) {
                    out.writeInt(neighbors.getNode(n));
                    out.writeFloat(neighbors.getScore(n));
                }
            }
        }
    }

    @Deprecated
    @Experimental
    public static OnHeapGraphIndex load(RandomAccessReader in, int dimension, double overflowRatio, DiversityProvider diversityProvider) throws IOException {
        int magic = in.readInt();
        if (magic != 1978417170) {
            throw new IOException("Unsupported magic number: " + magic);
        }
        int version = in.readInt();
        if (version != 4) {
            throw new IOException("Unsupported version: " + version);
        }
        int layerCount = in.readInt();
        ArrayList<Integer> layerDegrees = new ArrayList<Integer>(layerCount);
        for (int level = 0; level < layerCount; ++level) {
            layerDegrees.add(in.readInt());
        }
        int entryNode = in.readInt();
        boolean isHierarchical = layerCount > 1;
        OnHeapGraphIndex graph = new OnHeapGraphIndex(layerDegrees, dimension, overflowRatio, diversityProvider, isHierarchical);
        HashMap<Integer, Integer> nodeLevelMap = new HashMap<Integer, Integer>();
        for (int level = 0; level < layerCount; ++level) {
            int layerSize = in.readInt();
            for (int i = 0; i < layerSize; ++i) {
                int nodeId = in.readInt();
                int nNeighbors = in.readInt();
                NodeArray ca = new NodeArray(nNeighbors);
                for (int j = 0; j < nNeighbors; ++j) {
                    int neighbor = in.readInt();
                    float score = in.readFloat();
                    ca.addInOrder(neighbor, score);
                }
                graph.connectNode(level, nodeId, ca);
                nodeLevelMap.put(nodeId, level);
            }
        }
        for (Integer k : nodeLevelMap.keySet()) {
            ImmutableGraphIndex.NodeAtLevel nal = new ImmutableGraphIndex.NodeAtLevel((Integer)nodeLevelMap.get(k), k);
            graph.markComplete(nal);
        }
        graph.setDegrees(layerDegrees);
        graph.updateEntryNode(new ImmutableGraphIndex.NodeAtLevel(graph.getMaxLevel(), entryNode));
        return graph;
    }

    static final class CompletionTracker
    implements Accountable {
        private final AtomicInteger logicalClock = new AtomicInteger();
        private volatile AtomicIntegerArray completionTimes;
        private final StampedLock sl = new StampedLock();

        public CompletionTracker(int initialSize) {
            this.completionTimes = new AtomicIntegerArray(initialSize);
            for (int i = 0; i < initialSize; ++i) {
                this.completionTimes.set(i, Integer.MAX_VALUE);
            }
        }

        void markComplete(int node) {
            long stamp;
            int completionClock = this.logicalClock.getAndIncrement();
            this.ensureCapacity(node);
            do {
                stamp = this.sl.tryOptimisticRead();
                this.completionTimes.set(node, completionClock);
            } while (!this.sl.validate(stamp));
        }

        int clock() {
            return this.logicalClock.get();
        }

        public int completedAt(int node) {
            AtomicIntegerArray ct = this.completionTimes;
            if (node >= ct.length()) {
                return Integer.MAX_VALUE;
            }
            return ct.get(node);
        }

        @Override
        public long ramBytesUsed() {
            int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
            return (long)(REF_BYTES + 4 + REF_BYTES) + 4L * (long)this.completionTimes.length();
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        private void ensureCapacity(int node) {
            if (node < this.completionTimes.length()) {
                return;
            }
            long stamp = this.sl.writeLock();
            try {
                AtomicIntegerArray oldArray = this.completionTimes;
                if (node >= oldArray.length()) {
                    int newSize = (node + 1) * 2;
                    AtomicIntegerArray newArr = new AtomicIntegerArray(newSize);
                    for (int i = 0; i < newSize; ++i) {
                        if (i < oldArray.length()) {
                            newArr.set(i, oldArray.get(i));
                            continue;
                        }
                        newArr.set(i, Integer.MAX_VALUE);
                    }
                    this.completionTimes = newArr;
                }
            }
            finally {
                this.sl.unlockWrite(stamp);
            }
        }
    }

    private class FrozenView
    implements ImmutableGraphIndex.View {
        private FrozenView() {
        }

        @Override
        public NodesIterator getNeighborsIterator(int level, int node) {
            return OnHeapGraphIndex.this.getNeighborsIterator(level, node);
        }

        @Override
        public void processNeighbors(int level, int node, ScoreFunction scoreFunction, ImmutableGraphIndex.IntMarker visited, ImmutableGraphIndex.NeighborProcessor neighborProcessor) {
            NodesIterator it = this.getNeighborsIterator(level, node);
            while (it.hasNext()) {
                int friendOrd = it.nextInt();
                if (!visited.mark(friendOrd)) continue;
                float friendSimilarity = scoreFunction.similarityTo(friendOrd);
                neighborProcessor.process(friendOrd, friendSimilarity);
            }
        }

        @Override
        public int size() {
            return OnHeapGraphIndex.this.size(0);
        }

        @Override
        public ImmutableGraphIndex.NodeAtLevel entryNode() {
            return OnHeapGraphIndex.this.entryPoint.get();
        }

        @Override
        public Bits liveNodes() {
            return OnHeapGraphIndex.this.deletedNodes.cardinality() == 0 ? Bits.ALL : Bits.inverseOf(OnHeapGraphIndex.this.deletedNodes);
        }

        @Override
        public int getIdUpperBound() {
            return OnHeapGraphIndex.this.getIdUpperBound();
        }

        @Override
        public boolean contains(int level, int node) {
            return OnHeapGraphIndex.this.contains(level, node);
        }

        @Override
        public void close() {
        }

        public String toString() {
            ImmutableGraphIndex.NodeAtLevel entry = this.entryNode();
            return String.format("%s(size=%d, entryNode=%s)", this.getClass().getSimpleName(), this.size(), entry);
        }
    }

    public class ConcurrentGraphIndexView
    extends FrozenView {
        private final int timestamp;

        public ConcurrentGraphIndexView() {
            this.timestamp = OnHeapGraphIndex.this.completions.clock();
        }

        @Override
        public NodesIterator getNeighborsIterator(final int level, final int node) {
            final NodesIterator it = OnHeapGraphIndex.this.getNeighborsIterator(level, node);
            return new NodesIterator(){
                int nextNode = this.advance();
                final /* synthetic */ ConcurrentGraphIndexView this$1;
                {
                    this.this$1 = this$1;
                }

                private int advance() {
                    while (it.hasNext()) {
                        int n = it.nextInt();
                        if (this.this$1.OnHeapGraphIndex.this.completions.completedAt(n) >= this.this$1.timestamp) continue;
                        return n;
                    }
                    return Integer.MIN_VALUE;
                }

                @Override
                public int size() {
                    NodesIterator it2 = this.this$1.OnHeapGraphIndex.this.getNeighborsIterator(level, node);
                    int size = 0;
                    while (it2.hasNext()) {
                        int n = it2.nextInt();
                        if (this.this$1.OnHeapGraphIndex.this.completions.completedAt(n) >= this.this$1.timestamp) continue;
                        ++size;
                    }
                    return size;
                }

                @Override
                public int nextInt() {
                    int current = this.nextNode;
                    if (current == Integer.MIN_VALUE) {
                        throw new NoSuchElementException();
                    }
                    this.nextNode = this.advance();
                    return current;
                }

                @Override
                public boolean hasNext() {
                    return this.nextNode != Integer.MIN_VALUE;
                }
            };
        }
    }
}

