/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.ConcurrentNeighborSet;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NeighborArray;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.IntStream;

/**
 * Builds an {@link OnHeapGraphIndex} over a set of vectors.
 *
 * <p>{@link #build()} inserts every vector from a parallel stream, so {@link #addGraphNode(int, Object)}
 * runs on several threads at once. Per-thread state (vector copies, searchers and scratch arrays) is
 * held in {@link ThreadLocal}s. Nodes whose insertion is still running are tracked in
 * {@code insertionsInProgress} so concurrent inserts can consider each other as neighbor candidates.
 *
 * @param <T> the vector representation: {@code byte[]} for {@link VectorEncoding#BYTE},
 *     {@code float[]} for {@link VectorEncoding#FLOAT32}
 */
public class GraphIndexBuilder<T> {
    private final int beamWidth;
    private final ThreadLocal<NeighborArray> naturalScratch;
    private final ThreadLocal<NeighborArray> concurrentScratch;
    private final VectorSimilarityFunction similarityFunction;
    private final float neighborOverflow;
    private final VectorEncoding vectorEncoding;
    private final ThreadLocal<GraphSearcher<?>> graphSearcher;
    final OnHeapGraphIndex<T> graph;
    /** Nodes whose {@code addGraphNode} call has started but not yet finished. */
    private final ConcurrentSkipListSet<Integer> insertionsInProgress = new ConcurrentSkipListSet<>();
    /** Per-thread copy of the vectors; holds the "left-hand" vector during pairwise scoring. */
    private final ThreadLocal<RandomAccessVectorValues<T>> vectors;
    /** Second per-thread copy, queried for the "right-hand" vector while the first one is held. */
    private final ThreadLocal<RandomAccessVectorValues<T>> vectorsCopy;

    /**
     * Creates a builder.
     *
     * @param vectorValues source of vectors; {@code copy()} is called lazily, twice per thread
     * @param vectorEncoding how {@code T} is interpreted when scoring ({@code byte[]} or {@code float[]})
     * @param similarityFunction function used to score pairs of vectors
     * @param M maximum connections per node; passed to the graph and used to size scratch arrays
     * @param beamWidth beam width used when searching for a new node's neighbors
     * @param neighborOverflow passed to {@code ConcurrentNeighborSet.backlink}
     * @param alpha passed to every {@code ConcurrentNeighborSet} the graph creates
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     * @throws NullPointerException if {@code vectorValues}, {@code vectorEncoding} or
     *     {@code similarityFunction} is null
     */
    public GraphIndexBuilder(RandomAccessVectorValues<T> vectorValues, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, float alpha) {
        Objects.requireNonNull(vectorValues, "vectorValues");
        // Assigned before any lambda below can read them. Two copies per thread let one vector be
        // held while the other copy is queried (presumably because vectorValue() may reuse an
        // internal buffer -- confirm for the RandomAccessVectorValues implementation in use).
        this.vectors = ThreadLocal.withInitial(vectorValues::copy);
        this.vectorsCopy = ThreadLocal.withInitial(vectorValues::copy);
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        this.neighborOverflow = neighborOverflow;
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.beamWidth = beamWidth;

        NeighborSimilarity similarity = node1 -> {
            T v1 = this.vectors.get().vectorValue(node1);
            return node2 -> this.scoreBetween(v1, this.vectorsCopy.get().vectorValue(node2));
        };
        this.graph = new OnHeapGraphIndex<>(M, (node, m) -> new ConcurrentNeighborSet(node, m, similarity, alpha));
        this.graphSearcher = ThreadLocal.withInitial(() -> new GraphSearcher.Builder<T>(this.graph.getView()).withConcurrentUpdates().build());
        // Large enough for a full beam of search results or M + 1 neighbors.
        int scratchSize = Math.max(beamWidth, M + 1);
        this.naturalScratch = ThreadLocal.withInitial(() -> new NeighborArray(scratchSize));
        this.concurrentScratch = ThreadLocal.withInitial(() -> new NeighborArray(scratchSize));
    }

    /**
     * Inserts every vector from the source (in parallel), then calls {@link #complete()}.
     *
     * @return the built graph
     */
    public OnHeapGraphIndex<T> build() {
        int size = this.vectors.get().size();
        // Pass the thread's RandomAccessVectorValues itself so the (int, RandomAccessVectorValues)
        // overload is chosen; it looks up the vector for node i.
        IntStream.range(0, size).parallel().forEach(i -> this.addGraphNode(i, this.vectors.get()));
        this.complete();
        return this.graph;
    }

    /**
     * Finalizes the graph: validates the entry node, runs {@code cleanup()} on every node's neighbor
     * set (in parallel), moves the entry node to an approximate medoid and validates it again.
     */
    public void complete() {
        this.graph.validateEntryNode();
        IntStream.range(0, this.graph.size()).parallel().forEach(i -> this.graph.getNeighbors(i).cleanup());
        this.graph.updateEntryNode(this.approximateMedioid());
        this.graph.validateEntryNode();
    }

    /**
     * Inserts {@code node} using its vector from {@code values}.
     *
     * @see #addGraphNode(int, Object)
     */
    public long addGraphNode(int node, RandomAccessVectorValues<T> values) {
        return this.addGraphNode(node, values.vectorValue(node));
    }

    /** Returns the graph under construction. */
    public OnHeapGraphIndex<T> getGraph() {
        return this.graph;
    }

    /** Returns the number of {@code addGraphNode} calls currently in flight. */
    public int insertsInProgress() {
        return this.insertionsInProgress.size();
    }

    /**
     * Inserts {@code node} into the graph and links it to its neighbors.
     *
     * <p>Candidates come from two sources: a beam search over the current graph ("natural"
     * candidates) and every other node whose insertion was in progress when this one started
     * ("concurrent" candidates), which the search may not yet be able to reach.
     *
     * @param node ordinal of the node to insert
     * @param value the node's vector
     * @return the value of {@code graph.ramBytesUsedOneNode(0)} (presumably a per-node heap estimate)
     */
    public long addGraphNode(int node, T value) {
        // Registered with the graph before being advertised as in-progress, presumably so that a
        // concurrent inserter that sees this node in insertionsInProgress can find its neighbor set.
        this.graph.addNode(node);
        this.insertionsInProgress.add(node);
        // Snapshot: only nodes already in flight are treated as concurrent candidates.
        Set<Integer> inProgressBefore = this.insertionsInProgress.clone();
        try {
            int ep = this.graph.entry();
            GraphSearcher<?> gs = this.graphSearcher.get();
            NeighborSimilarity.ExactScoreFunction scoreFunction = i -> this.scoreBetween(this.vectorsCopy.get().vectorValue(i), value);
            // Exclude the node itself from its own search results.
            ExcludingBits bits = new ExcludingBits(node);
            SearchResult candidates = gs.searchInternal(scoreFunction, null, this.beamWidth, ep, bits);
            NeighborArray natural = this.getNaturalCandidates(candidates.getNodes());
            NeighborArray concurrent = this.getConcurrentCandidates(node, inProgressBefore);
            this.updateNeighbors(node, natural, concurrent);
            this.graph.markComplete(node);
        } finally {
            this.insertionsInProgress.remove(node);
        }
        return this.graph.ramBytesUsedOneNode(0);
    }

    /**
     * Approximates a medoid by hill-climbing from the current entry node.
     *
     * <p>At each step the walk moves to whichever of {current node, its neighbors} has the highest
     * mean similarity to its own neighbors, and stops when no neighbor beats the current node.
     * Scores are similarities (higher = closer), so maximizing favors well-embedded nodes over outliers.
     * The score strictly increases on every move, so the walk terminates.
     *
     * @return ordinal of the chosen node
     */
    private int approximateMedioid() {
        RandomAccessVectorValues<T> v1 = this.vectors.get();
        RandomAccessVectorValues<T> v2 = this.vectorsCopy.get();
        int current = this.graph.getView().entryNode();
        while (true) {
            ConcurrentNeighborSet.ConcurrentNeighborArray neighbors = this.graph.getNeighbors(current).getCurrent();
            int[] ids = neighbors.node();
            int best = current;
            double bestScore = this.meanNeighborSimilarity(current, v1, v2);
            for (int j = 0; j < neighbors.size; j++) {
                int candidate = ids[j];
                double score = this.meanNeighborSimilarity(candidate, v1, v2);
                // Strict '>' keeps the current node on ties and never selects a NaN score.
                if (score > bestScore) {
                    best = candidate;
                    bestScore = score;
                }
            }
            if (best == current) {
                return current;
            }
            current = best;
        }
    }

    /**
     * Returns the mean similarity between {@code node} and its current neighbors, or
     * {@link Double#NEGATIVE_INFINITY} when it has none (so such nodes are never preferred).
     */
    private double meanNeighborSimilarity(int node, RandomAccessVectorValues<T> v1, RandomAccessVectorValues<T> v2) {
        ConcurrentNeighborSet.ConcurrentNeighborArray neighbors = this.graph.getNeighbors(node).getCurrent();
        int count = neighbors.size;
        if (count == 0) {
            return Double.NEGATIVE_INFINITY;
        }
        T vector = v1.vectorValue(node);
        int[] ids = neighbors.node();
        double total = 0.0;
        for (int j = 0; j < count; j++) {
            total += this.scoreBetween(vector, v2.vectorValue(ids[j]));
        }
        return total / count;
    }

    /** Merges the candidates into {@code node}'s neighbor set, then adds backlinks from those neighbors. */
    private void updateNeighbors(int node, NeighborArray natural, NeighborArray concurrent) {
        ConcurrentNeighborSet neighbors = this.graph.getNeighbors(node);
        neighbors.insertDiverse(natural, concurrent);
        neighbors.backlink(this.graph::getNeighbors, this.neighborOverflow);
    }

    /** Copies search results, in their given order, into this thread's cleared natural scratch array. */
    private NeighborArray getNaturalCandidates(SearchResult.NodeScore[] candidates) {
        NeighborArray scratch = this.naturalScratch.get();
        scratch.clear();
        for (SearchResult.NodeScore candidate : candidates) {
            scratch.addInOrder(candidate.node, candidate.score);
        }
        return scratch;
    }

    /** Scores every in-progress node except {@code newNode} against it, into this thread's concurrent scratch array. */
    private NeighborArray getConcurrentCandidates(int newNode, Set<Integer> inProgress) {
        NeighborArray scratch = this.concurrentScratch.get();
        scratch.clear();
        // Loop-invariant lookup; only vectorsCopy is queried inside the loop, so this value stays valid.
        T newVector = this.vectors.get().vectorValue(newNode);
        RandomAccessVectorValues<T> others = this.vectorsCopy.get();
        for (int n : inProgress) {
            if (n != newNode) {
                scratch.insertSorted(n, this.scoreBetween(newVector, others.vectorValue(n)));
            }
        }
        return scratch;
    }

    /** Scores two vectors using this builder's encoding and similarity function. */
    protected float scoreBetween(T v1, T v2) {
        return GraphIndexBuilder.scoreBetween(this.vectorEncoding, this.similarityFunction, v1, v2);
    }

    /**
     * Scores two vectors, casting them to {@code byte[]} or {@code float[]} according to {@code encoding}.
     *
     * @throws IllegalArgumentException if the encoding is not handled
     */
    static <T> float scoreBetween(VectorEncoding encoding, VectorSimilarityFunction similarityFunction, T v1, T v2) {
        switch (encoding) {
            case BYTE:
                return similarityFunction.compare((byte[]) v1, (byte[]) v2);
            case FLOAT32:
                return similarityFunction.compare((float[]) v1, (float[]) v2);
            default:
                throw new IllegalArgumentException("Unsupported vector encoding: " + encoding);
        }
    }

    /** {@link Bits} that accepts every index except one; {@link #length()} is unsupported. */
    private static class ExcludingBits
    implements Bits {
        private final int excluded;

        public ExcludingBits(int excluded) {
            this.excluded = excluded;
        }

        @Override
        public boolean get(int index) {
            return index != this.excluded;
        }

        @Override
        public int length() {
            throw new UnsupportedOperationException();
        }
    }
}

