/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.ConcurrentNeighborMap;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.NodeArray;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.DenseIntMap;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.util.SparseIntMap;
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.io.DataOutput;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.StampedLock;
import java.util.stream.IntStream;
import org.agrona.collections.IntArrayList;

public class OnHeapGraphIndex
implements GraphIndex {
    private final AtomicReference<GraphIndex.NodeAtLevel> entryPoint;
    final List<ConcurrentNeighborMap> layers = new ArrayList<ConcurrentNeighborMap>();
    private final CompletionTracker completions;
    private final ThreadSafeGrowableBitSet deletedNodes = new ThreadSafeGrowableBitSet(0);
    private final AtomicInteger maxNodeId = new AtomicInteger(-1);
    final IntArrayList maxDegrees;
    private final double overflowRatio;
    public final ConcurrentMap<GraphIndex.NodeAtLevel, VectorFloat<?>> constructionBatch;

    OnHeapGraphIndex(List<Integer> maxDegrees, double overflowRatio, BuildScoreProvider scoreProvider, float alpha, int batchSize) {
        this.overflowRatio = overflowRatio;
        this.maxDegrees = new IntArrayList();
        this.setDegrees(maxDegrees);
        this.entryPoint = new AtomicReference();
        this.completions = new CompletionTracker(1024);
        this.layers.add(new ConcurrentNeighborMap(new DenseIntMap<ConcurrentNeighborMap.Neighbors>(1024), scoreProvider, this.getDegree(0), (int)((double)this.getDegree(0) * overflowRatio), alpha));
        this.constructionBatch = new ConcurrentHashMap(batchSize);
    }

    ConcurrentNeighborMap.Neighbors getNeighbors(int level, int node) {
        if (level >= this.layers.size()) {
            return null;
        }
        return this.layers.get(level).get(node);
    }

    @Override
    public int size(int level) {
        return this.layers.get(level).size();
    }

    public void addNode(GraphIndex.NodeAtLevel nodeLevel) {
        this.ensureLayersExist(nodeLevel.level);
        for (int i = 0; i <= nodeLevel.level; ++i) {
            this.layers.get(i).addNode(nodeLevel.node);
        }
        this.maxNodeId.accumulateAndGet(nodeLevel.node, Math::max);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void ensureLayersExist(int level) {
        for (int i = this.layers.size(); i <= level; ++i) {
            List<ConcurrentNeighborMap> list = this.layers;
            synchronized (list) {
                if (i == this.layers.size()) {
                    ConcurrentNeighborMap denseMap = this.layers.get(0);
                    ConcurrentNeighborMap map = new ConcurrentNeighborMap(new SparseIntMap<ConcurrentNeighborMap.Neighbors>(), denseMap.scoreProvider, this.getDegree(level), (int)((double)this.getDegree(level) * this.overflowRatio), denseMap.alpha);
                    this.layers.add(map);
                }
                continue;
            }
        }
    }

    void addNode(int level, int nodeId, NodeArray nodes) {
        assert (nodes != null);
        this.ensureLayersExist(level);
        this.layers.get(level).addNode(nodeId, nodes);
        this.maxNodeId.accumulateAndGet(nodeId, Math::max);
    }

    public void markDeleted(int node) {
        this.deletedNodes.set(node);
    }

    void markComplete(GraphIndex.NodeAtLevel nodeLevel) {
        this.entryPoint.accumulateAndGet(nodeLevel, (oldEntry, newEntry) -> {
            if (oldEntry == null || newEntry.level > oldEntry.level) {
                return newEntry;
            }
            return oldEntry;
        });
        this.completions.markComplete(nodeLevel.node);
    }

    void updateEntryNode(GraphIndex.NodeAtLevel newEntry) {
        this.entryPoint.set(newEntry);
    }

    GraphIndex.NodeAtLevel entry() {
        return this.entryPoint.get();
    }

    @Override
    public NodesIterator getNodes(int level) {
        return NodesIterator.fromPrimitiveIterator(this.nodeStream(level).iterator(), this.layers.get(level).size());
    }

    IntStream nodeStream(int level) {
        ConcurrentNeighborMap layer = this.layers.get(level);
        return level == 0 ? IntStream.range(0, this.getIdUpperBound()).filter(i -> layer.get(i) != null) : ((SparseIntMap)layer.neighbors).keysStream();
    }

    @Override
    public long ramBytesUsed() {
        long graphBytesUsed = IntStream.range(0, this.layers.size()).mapToLong(this::ramBytesUsedOneLayer).sum();
        return graphBytesUsed + this.completions.ramBytesUsed();
    }

    public long ramBytesUsedOneLayer(int layer) {
        int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
        int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
        long neighborSize = this.ramBytesUsedOneNode(layer) * (long)this.layers.get(layer).size();
        return (long)OH_BYTES + (long)REF_BYTES * 2L + (long)AH_BYTES + neighborSize;
    }

    public long ramBytesUsedOneNode(int layer) {
        int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        return (long)REF_BYTES + ConcurrentNeighborMap.Neighbors.ramBytesUsed(this.layers.get(layer).nodeArrayLength());
    }

    public String toString() {
        return String.format("OnHeapGraphIndex(size=%d, entryPoint=%s)", this.size(0), this.entryPoint.get());
    }

    @Override
    public void close() {
    }

    @Override
    public ConcurrentGraphIndexView getView() {
        return new ConcurrentGraphIndexView();
    }

    public GraphIndex.View getFrozenView() {
        return new FrozenView();
    }

    void validateEntryNode() {
        if (this.size(0) == 0) {
            return;
        }
        GraphIndex.NodeAtLevel entry = this.getView().entryNode();
        if (entry == null || this.getNeighbors(entry.level, entry.node) == null) {
            throw new IllegalStateException("Entry node was incompletely added! " + String.valueOf(entry));
        }
    }

    public ThreadSafeGrowableBitSet getDeletedNodes() {
        return this.deletedNodes;
    }

    int removeNode(int node) {
        int found = 0;
        for (ConcurrentNeighborMap layer : this.layers) {
            if (layer.remove(node) == null) continue;
            ++found;
        }
        this.deletedNodes.clear(node);
        return found;
    }

    @Override
    public int getIdUpperBound() {
        return this.maxNodeId.get() + 1;
    }

    @Override
    public boolean containsNode(int nodeId) {
        return this.layers.get(0).contains(nodeId);
    }

    public double getAverageDegree(int level) {
        return this.nodeStream(level).mapToDouble(i -> this.getNeighbors(level, i).size()).average().orElse(Double.NaN);
    }

    @Override
    public int getMaxLevel() {
        for (int lvl = 0; lvl < this.layers.size(); ++lvl) {
            if (this.layers.get(lvl).size() != 0) continue;
            return lvl - 1;
        }
        return this.layers.size() - 1;
    }

    @Override
    public int getDegree(int level) {
        if (level >= this.maxDegrees.size()) {
            return this.maxDegrees.get(this.maxDegrees.size() - 1);
        }
        return this.maxDegrees.get(level);
    }

    @Override
    public int maxDegree() {
        return this.maxDegrees.stream().mapToInt(i -> i).max().orElseThrow();
    }

    public int getLayerSize(int level) {
        return this.layers.get(level).size();
    }

    public void setDegrees(List<Integer> layerDegrees) {
        this.maxDegrees.clear();
        this.maxDegrees.addAll(layerDegrees);
    }

    public void save(DataOutput out) {
        if (this.deletedNodes.cardinality() > 0) {
            throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first");
        }
        try (ConcurrentGraphIndexView view = this.getView();){
            out.writeInt(this.layers.size());
            assert (view.entryNode().level == this.getMaxLevel());
            out.writeInt(view.entryNode().node);
            for (int level = 0; level < this.layers.size(); ++level) {
                out.writeInt(this.size(level));
                out.writeInt(this.getDegree(level));
                ConcurrentNeighborMap baseLayer = this.layers.get(level);
                baseLayer.forEach((nodeId, neighbors) -> {
                    try {
                        NodesIterator iterator = neighbors.iterator();
                        out.writeInt(nodeId);
                        out.writeInt(iterator.size());
                        for (int n = 0; n < iterator.size(); ++n) {
                            out.writeInt(iterator.nextInt());
                        }
                        assert (!iterator.hasNext());
                    }
                    catch (IOException e) {
                        throw new UncheckedIOException(e);
                    }
                });
            }
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    static final class CompletionTracker
    implements Accountable {
        private final AtomicInteger logicalClock = new AtomicInteger();
        private volatile AtomicIntegerArray completionTimes;
        private final StampedLock sl = new StampedLock();

        public CompletionTracker(int initialSize) {
            this.completionTimes = new AtomicIntegerArray(initialSize);
            for (int i = 0; i < initialSize; ++i) {
                this.completionTimes.set(i, Integer.MAX_VALUE);
            }
        }

        void markComplete(int node) {
            long stamp;
            int completionClock = this.logicalClock.getAndIncrement();
            this.ensureCapacity(node);
            do {
                stamp = this.sl.tryOptimisticRead();
                this.completionTimes.set(node, completionClock);
            } while (!this.sl.validate(stamp));
        }

        int clock() {
            return this.logicalClock.get();
        }

        public int completedAt(int node) {
            AtomicIntegerArray ct = this.completionTimes;
            if (node >= ct.length()) {
                return Integer.MAX_VALUE;
            }
            return ct.get(node);
        }

        @Override
        public long ramBytesUsed() {
            int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
            return (long)(REF_BYTES + 4 + REF_BYTES) + 4L * (long)this.completionTimes.length();
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        private void ensureCapacity(int node) {
            if (node < this.completionTimes.length()) {
                return;
            }
            long stamp = this.sl.writeLock();
            try {
                AtomicIntegerArray oldArray = this.completionTimes;
                if (node >= oldArray.length()) {
                    int newSize = (node + 1) * 2;
                    AtomicIntegerArray newArr = new AtomicIntegerArray(newSize);
                    for (int i = 0; i < newSize; ++i) {
                        if (i < oldArray.length()) {
                            newArr.set(i, oldArray.get(i));
                            continue;
                        }
                        newArr.set(i, Integer.MAX_VALUE);
                    }
                    this.completionTimes = newArr;
                }
            }
            finally {
                this.sl.unlockWrite(stamp);
            }
        }
    }

    public class ConcurrentGraphIndexView
    extends FrozenView {
        private final int timestamp;

        public ConcurrentGraphIndexView() {
            this.timestamp = OnHeapGraphIndex.this.completions.clock();
        }

        @Override
        public NodesIterator getNeighborsIterator(int level, int node) {
            final NodesIterator it = OnHeapGraphIndex.this.getNeighbors(level, node).iterator();
            return new NodesIterator(){
                int nextNode = this.advance();
                final /* synthetic */ ConcurrentGraphIndexView this$1;
                {
                    this.this$1 = this$1;
                }

                private int advance() {
                    while (it.hasNext()) {
                        int n = it.nextInt();
                        if (this.this$1.OnHeapGraphIndex.this.completions.completedAt(n) >= this.this$1.timestamp) continue;
                        return n;
                    }
                    return Integer.MIN_VALUE;
                }

                @Override
                public int size() {
                    throw new UnsupportedOperationException();
                }

                @Override
                public int nextInt() {
                    int current = this.nextNode;
                    if (current == Integer.MIN_VALUE) {
                        throw new IndexOutOfBoundsException();
                    }
                    this.nextNode = this.advance();
                    return current;
                }

                @Override
                public boolean hasNext() {
                    return this.nextNode != Integer.MIN_VALUE;
                }
            };
        }
    }

    private class FrozenView
    implements GraphIndex.View {
        private FrozenView() {
        }

        @Override
        public NodesIterator getNeighborsIterator(int level, int node) {
            return OnHeapGraphIndex.this.getNeighbors(level, node).iterator();
        }

        @Override
        public int size() {
            return OnHeapGraphIndex.this.size(0);
        }

        @Override
        public GraphIndex.NodeAtLevel entryNode() {
            return OnHeapGraphIndex.this.entryPoint.get();
        }

        @Override
        public Bits liveNodes() {
            return OnHeapGraphIndex.this.deletedNodes.cardinality() == 0 ? Bits.ALL : Bits.inverseOf(OnHeapGraphIndex.this.deletedNodes);
        }

        @Override
        public int getIdUpperBound() {
            return OnHeapGraphIndex.this.getIdUpperBound();
        }

        @Override
        public void close() {
        }

        public String toString() {
            GraphIndex.NodeAtLevel entry = this.entryNode();
            return String.format("%s(size=%d, entryNode=%s)", this.getClass().getSimpleName(), this.size(), entry);
        }
    }
}

