/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.elasticsearch.simdvec.QuantizedByteVectorValuesAccess;
import org.elasticsearch.simdvec.internal.Similarities;

/**
 * Supplies scorers that compare int7 scalar-quantized vectors against each other by ordinal,
 * reading them directly from the memory segments exposed by a {@link MemorySegmentAccessInput}.
 *
 * <p>Each vector is stored as {@code dims} quantized bytes immediately followed by a 4-byte float
 * correction term (called the "offset" below), so vector {@code ord} starts at byte
 * {@code ord * (dims + Float.BYTES)}. When the input cannot expose a segment for the requested
 * range, scoring falls back to copying the bytes onto the heap and delegating to
 * {@link #fallbackScorer}.
 */
public abstract sealed class Int7SQVectorScorerSupplier
implements RandomVectorScorerSupplier,
QuantizedByteVectorValuesAccess {
    /** Number of bits per quantized component. */
    static final byte BITS = 7;
    /** Number of quantized bytes per vector. */
    final int dims;
    /** Number of stored vectors ({@code values.size()}); valid ordinals are {@code [0, maxOrd)}. */
    final int maxOrd;
    /** Multiplier applied to the raw integer similarity before the correction terms are added. */
    final float scoreCorrectionConstant;
    /** Raw vector data laid out as described in the class documentation. */
    final MemorySegmentAccessInput input;
    /** Quantized vector values backing {@link #input}; returned by {@link #get()}. */
    final QuantizedByteVectorValues values;
    /** Heap-based scorer used when {@link #input} cannot expose a memory segment. */
    final ScalarQuantizedVectorSimilarity fallbackScorer;
    /**
     * Whether heap-backed segments ({@link MemorySegment#ofArray}) may be passed directly to
     * {@link #bulkScoreFromSegment}; otherwise ordinals and scores are staged in off-heap memory.
     * NOTE(review): gated on JDK 22+, presumably because the native bulk kernels can only accept
     * heap segments from that release on — confirm.
     */
    static final boolean SUPPORTS_HEAP_SEGMENTS = Runtime.version().feature() >= 22;

    protected Int7SQVectorScorerSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant, ScalarQuantizedVectorSimilarity fallbackScorer) {
        this.input = input;
        this.values = values;
        this.dims = values.dimension();
        this.maxOrd = values.size();
        this.scoreCorrectionConstant = scoreCorrectionConstant;
        this.fallbackScorer = fallbackScorer;
    }

    /**
     * Validates that {@code ord} addresses a stored vector.
     *
     * @param ord the ordinal to check
     * @throws IllegalArgumentException if {@code ord} is negative or not less than {@link #maxOrd}
     */
    protected final void checkOrdinal(int ord) {
        // Valid ordinals are [0, maxOrd): maxOrd itself would address bytes past the last vector.
        if (ord < 0 || ord >= this.maxOrd) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    /**
     * Scores the vector at {@code firstOrd} against the first {@code numNodes} entries of
     * {@code ordinals}, writing results into the matching slots of {@code scores}.
     *
     * <p>If the whole input cannot be exposed as a single segment, each pair is scored individually
     * through {@link #scoreFromOrds}. Otherwise the work is delegated to
     * {@link #bulkScoreFromSegment}. Before JDK 22 the ordinals and scores are first copied through
     * a confined off-heap arena.
     */
    final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException {
        MemorySegment vectorsSeg = this.input.segmentSliceOrNull(0L, this.input.length());
        if (vectorsSeg == null) {
            for (int i = 0; i < numNodes; ++i) {
                scores[i] = this.scoreFromOrds(firstOrd, ordinals[i]);
            }
            return;
        }
        int vectorLength = this.dims;
        int vectorPitch = vectorLength + Float.BYTES;
        if (SUPPORTS_HEAP_SEGMENTS) {
            this.bulkScoreFromSegment(vectorsSeg, vectorLength, vectorPitch, firstOrd, MemorySegment.ofArray(ordinals), MemorySegment.ofArray(scores), numNodes);
            return;
        }
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment ordinalsMemorySegment = arena.allocate((long) numNodes * Integer.BYTES, 32L);
            MemorySegment scoresMemorySegment = arena.allocate((long) numNodes * Float.BYTES, 32L);
            MemorySegment.copy(ordinals, 0, ordinalsMemorySegment, ValueLayout.JAVA_INT, 0L, numNodes);
            this.bulkScoreFromSegment(vectorsSeg, vectorLength, vectorPitch, firstOrd, ordinalsMemorySegment, scoresMemorySegment, numNodes);
            MemorySegment.copy(scoresMemorySegment, ValueLayout.JAVA_FLOAT, 0L, scores, 0, numNodes);
        }
    }

    /**
     * Scores the vector at {@code firstOrd} against the vector at {@code secondOrd}.
     *
     * <p>Uses {@link #scoreFromSegments} when both vectors can be exposed as memory segments.
     * Otherwise it uses the heap-copying {@link #fallbackScore}.
     */
    final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException {
        int length = this.dims;
        long pitch = (long) length + Float.BYTES;
        long firstByteOffset = firstOrd * pitch;
        long secondByteOffset = secondOrd * pitch;
        MemorySegment firstSeg = this.input.segmentSliceOrNull(firstByteOffset, length);
        if (firstSeg == null) {
            return this.fallbackScore(firstByteOffset, secondByteOffset);
        }
        float firstOffset = Float.intBitsToFloat(this.input.readInt(firstByteOffset + length));
        MemorySegment secondSeg = this.input.segmentSliceOrNull(secondByteOffset, length);
        if (secondSeg == null) {
            return this.fallbackScore(firstByteOffset, secondByteOffset);
        }
        float secondOffset = Float.intBitsToFloat(this.input.readInt(secondByteOffset + length));
        return this.scoreFromSegments(firstSeg, firstOffset, secondSeg, secondOffset);
    }

    /**
     * Computes the similarity score for two quantized vectors.
     *
     * @param a the first vector's {@code dims} quantized bytes
     * @param aOffset the first vector's correction term
     * @param b the second vector's {@code dims} quantized bytes
     * @param bOffset the second vector's correction term
     * @return the similarity score
     */
    abstract float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset);

    /**
     * Default bulk path: scores {@code firstOrd} against each ordinal stored in {@code ordinals}
     * one pair at a time via {@link #scoreFromSegments}. Subclasses may override this with a
     * batched kernel.
     *
     * @param vectors segment spanning all stored vectors
     * @param vectorLength number of quantized bytes per vector
     * @param vectorPitch bytes between the starts of consecutive vectors
     * @param firstOrd ordinal scored against every target
     * @param ordinals segment of {@code int} target ordinals
     * @param scores segment of {@code float} results, written at the same index as each ordinal
     * @param numNodes number of targets to score
     */
    protected void bulkScoreFromSegment(MemorySegment vectors, int vectorLength, int vectorPitch, int firstOrd, MemorySegment ordinals, MemorySegment scores, int numNodes) {
        long firstByteOffset = (long) firstOrd * vectorPitch;
        MemorySegment a = vectors.asSlice(firstByteOffset, vectorLength);
        float aOffset = readCorrection(vectors, firstByteOffset + vectorLength);
        for (int i = 0; i < numNodes; ++i) {
            long secondByteOffset = (long) ordinals.getAtIndex(ValueLayout.JAVA_INT, i) * vectorPitch;
            MemorySegment b = vectors.asSlice(secondByteOffset, vectorLength);
            float bOffset = readCorrection(vectors, secondByteOffset + vectorLength);
            scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, this.scoreFromSegments(a, aOffset, b, bOffset));
        }
    }

    /** Reads the 4-byte float correction term stored at {@code byteOffset} in {@code vectors}. */
    private static float readCorrection(MemorySegment vectors, long byteOffset) {
        return Float.intBitsToFloat(vectors.get(ValueLayout.JAVA_INT_UNALIGNED, byteOffset));
    }

    /**
     * Copies both vectors and their correction terms onto the heap and scores them with
     * {@link #fallbackScorer}. Used when the input cannot expose a memory segment.
     */
    private float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException {
        byte[] a = new byte[this.dims];
        this.input.readBytes(firstByteOffset, a, 0, a.length);
        float aOffsetValue = Float.intBitsToFloat(this.input.readInt(firstByteOffset + this.dims));
        byte[] b = new byte[this.dims];
        this.input.readBytes(secondByteOffset, b, 0, b.length);
        float bOffsetValue = Float.intBitsToFloat(this.input.readInt(secondByteOffset + this.dims));
        return this.fallbackScorer.score(a, aOffsetValue, b, bOffsetValue);
    }

    /**
     * Returns a scorer that compares a settable "scoring ordinal" against other ordinals.
     * {@code setScoringOrdinal} must be called before scoring; until then the ordinal is -1.
     */
    @Override
    public UpdateableRandomVectorScorer scorer() {
        return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer((KnnVectorValues) this.values) {
            /** Ordinal every other vector is scored against; -1 until set. */
            private int ord = -1;

            @Override
            public float score(int node) throws IOException {
                Int7SQVectorScorerSupplier.this.checkOrdinal(node);
                return Int7SQVectorScorerSupplier.this.scoreFromOrds(this.ord, node);
            }

            @Override
            public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
                Int7SQVectorScorerSupplier.this.bulkScoreFromOrds(this.ord, nodes, scores, numNodes);
            }

            @Override
            public void setScoringOrdinal(int node) throws IOException {
                Int7SQVectorScorerSupplier.this.checkOrdinal(node);
                this.ord = node;
            }
        };
    }

    @Override
    public QuantizedByteVectorValues get() {
        return this.values;
    }

    /** Returns {@code true} if {@code index} lies in {@code [0, length)}. */
    static boolean checkIndex(long index, long length) {
        return index >= 0L && index < length;
    }

    /** Scores vectors with maximum-inner-product similarity. */
    public static final class MaxInnerProductSupplier
    extends Int7SQVectorScorerSupplier {
        public MaxInnerProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
            super(input, values, scoreCorrectionConstant, ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, scoreCorrectionConstant, BITS));
        }

        @Override
        float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
            int dotProduct = Similarities.dotProduct7u(a, b, this.dims);
            assert dotProduct >= 0;
            return scaleScore((float) dotProduct * this.scoreCorrectionConstant + aOffset + bOffset);
        }

        @Override
        protected void bulkScoreFromSegment(MemorySegment vectors, int vectorLength, int vectorPitch, int firstOrd, MemorySegment ordinals, MemorySegment scores, int numNodes) {
            long firstByteOffset = (long) firstOrd * vectorPitch;
            MemorySegment firstVector = vectors.asSlice(firstByteOffset, vectorPitch);
            // The batched kernel fills `scores` with raw dot products; they are read back below,
            // combined with the correction terms and scaled in place.
            Similarities.dotProduct7uBulkWithOffsets(vectors, firstVector, this.dims, vectorPitch, ordinals, numNodes, scores);
            float aOffset = readCorrection(vectors, firstByteOffset + vectorLength);
            for (int i = 0; i < numNodes; ++i) {
                float dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i);
                long secondByteOffset = (long) ordinals.getAtIndex(ValueLayout.JAVA_INT, i) * vectorPitch;
                float bOffset = readCorrection(vectors, secondByteOffset + vectorLength);
                float adjustedDistance = dotProduct * this.scoreCorrectionConstant + aOffset + bOffset;
                scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, scaleScore(adjustedDistance));
            }
        }

        /**
         * Maps a corrected inner product to a positive score. Negative values map into (0, 1),
         * and non-negative values map to {@code distance + 1}.
         */
        private static float scaleScore(float adjustedDistance) {
            if (adjustedDistance < 0.0f) {
                return 1.0f / (1.0f - adjustedDistance);
            }
            return adjustedDistance + 1.0f;
        }

        @Override
        public MaxInnerProductSupplier copy() {
            return new MaxInnerProductSupplier(this.input.clone(), this.values, this.scoreCorrectionConstant);
        }
    }

    /** Scores vectors with dot-product similarity. */
    public static final class DotProductSupplier
    extends Int7SQVectorScorerSupplier {
        public DotProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
            super(input, values, scoreCorrectionConstant, ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.DOT_PRODUCT, scoreCorrectionConstant, BITS));
        }

        @Override
        float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
            int dotProduct = Similarities.dotProduct7u(a, b, this.dims);
            assert dotProduct >= 0;
            return normalizeScore((float) dotProduct * this.scoreCorrectionConstant + aOffset + bOffset);
        }

        @Override
        protected void bulkScoreFromSegment(MemorySegment vectors, int vectorLength, int vectorPitch, int firstOrd, MemorySegment ordinals, MemorySegment scores, int numNodes) {
            long firstByteOffset = (long) firstOrd * vectorPitch;
            MemorySegment firstVector = vectors.asSlice(firstByteOffset, vectorPitch);
            // The batched kernel fills `scores` with raw dot products; they are read back below,
            // combined with the correction terms and normalized in place.
            Similarities.dotProduct7uBulkWithOffsets(vectors, firstVector, this.dims, vectorPitch, ordinals, numNodes, scores);
            float aOffset = readCorrection(vectors, firstByteOffset + vectorLength);
            for (int i = 0; i < numNodes; ++i) {
                float dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i);
                long secondByteOffset = (long) ordinals.getAtIndex(ValueLayout.JAVA_INT, i) * vectorPitch;
                float bOffset = readCorrection(vectors, secondByteOffset + vectorLength);
                float adjustedDistance = dotProduct * this.scoreCorrectionConstant + aOffset + bOffset;
                scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalizeScore(adjustedDistance));
            }
        }

        /** Maps a corrected dot product to {@code max((1 + distance) / 2, 0)}. */
        private static float normalizeScore(float adjustedDistance) {
            return Math.max((1.0f + adjustedDistance) / 2.0f, 0.0f);
        }

        @Override
        public DotProductSupplier copy() {
            return new DotProductSupplier(this.input.clone(), this.values, this.scoreCorrectionConstant);
        }
    }

    /**
     * Scores vectors with Euclidean similarity. Correction terms are not used, and bulk scoring
     * uses the default pairwise path.
     */
    public static final class EuclideanSupplier
    extends Int7SQVectorScorerSupplier {
        public EuclideanSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
            super(input, values, scoreCorrectionConstant, ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.EUCLIDEAN, scoreCorrectionConstant, BITS));
        }

        @Override
        float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
            int squareDistance = Similarities.squareDistance7u(a, b, this.dims);
            float adjustedDistance = (float) squareDistance * this.scoreCorrectionConstant;
            return 1.0f / (1.0f + adjustedDistance);
        }

        @Override
        public EuclideanSupplier copy() {
            return new EuclideanSupplier(this.input.clone(), this.values, this.scoreCorrectionConstant);
        }
    }
}

