/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.quantization.CompressedVectors;
import io.github.jbellis.jvector.quantization.ImmutableBQVectors;
import io.github.jbellis.jvector.quantization.VectorCompressor;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.function.Supplier;
import java.util.stream.IntStream;

public class BinaryQuantization
implements VectorCompressor<long[]> {
    private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
    private final int dimension;

    public BinaryQuantization(int dimension) {
        this.dimension = dimension;
    }

    @Deprecated
    public static BinaryQuantization compute(RandomAccessVectorValues ravv) {
        return BinaryQuantization.compute(ravv, ForkJoinPool.commonPool());
    }

    @Deprecated
    public static BinaryQuantization compute(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
        return new BinaryQuantization(ravv.dimension());
    }

    @Override
    public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
        return new ImmutableBQVectors(this, (long[][])compressedVectors);
    }

    @Override
    public CompressedVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
        Supplier<RandomAccessVectorValues> ravvCopy = ravv.threadLocalSupplier();
        long[][] cv = (long[][])((ForkJoinTask)simdExecutor.submit(() -> (long[][])IntStream.range(0, ravv.size()).parallel().mapToObj(arg_0 -> this.lambda$encodeAll$1((Supplier)ravvCopy, arg_0)).toArray(x$0 -> new long[x$0][]))).join();
        return new ImmutableBQVectors(this, cv);
    }

    @Override
    public long[] encode(VectorFloat<?> v) {
        int M = (int)Math.ceil((double)v.length() / 64.0);
        long[] encoded = new long[M];
        this.encodeTo(v, encoded);
        return encoded;
    }

    @Override
    public void encodeTo(VectorFloat<?> v, long[] dest) {
        for (int i = 0; i < dest.length; ++i) {
            int idx;
            long bits = 0L;
            for (int j = 0; j < 64 && (idx = i * 64 + j) < v.length(); ++j) {
                if (!(v.get(idx) > 0.0f)) continue;
                bits |= 1L << j;
            }
            dest[i] = bits;
        }
    }

    @Override
    public int compressorSize() {
        return 4 + this.dimension * 4;
    }

    @Override
    public int compressedVectorSize() {
        int M = (int)Math.ceil((double)this.dimension / 64.0);
        return 8 * M;
    }

    @Override
    public void write(DataOutput out, int version) throws IOException {
        out.writeInt(this.dimension);
        vts.writeFloatVector(out, vts.createFloatVector(this.dimension));
    }

    public int getOriginalDimension() {
        return this.dimension;
    }

    public static BinaryQuantization load(RandomAccessReader in) throws IOException {
        int dimension = in.readInt();
        vts.readFloatVector(in, dimension);
        return new BinaryQuantization(dimension);
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BinaryQuantization that = (BinaryQuantization)o;
        return Objects.equals(this.dimension, that.dimension);
    }

    public int hashCode() {
        return Objects.hashCode(this.dimension);
    }

    public String toString() {
        return "BinaryQuantization";
    }

    @Override
    public double reconstructionError(VectorFloat<?> vector) {
        double sum = 0.0;
        for (int i = 0; i < vector.length(); ++i) {
            boolean bit = vector.get(i) > 0.0f;
            double diff = (bit ? 1.0f : 0.0f) - vector.get(i);
            sum += diff * diff;
        }
        return sum / (double)vector.length();
    }

    private /* synthetic */ long[] lambda$encodeAll$1(Supplier ravvCopy, int i) {
        RandomAccessVectorValues localRavv = (RandomAccessVectorValues)ravvCopy.get();
        VectorFloat<?> v = localRavv.getVector(i);
        return v == null ? new long[this.compressedVectorSize() / 8] : this.encode(v);
    }
}

