/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph.disk;

import io.github.jbellis.jvector.disk.RandomAccessWriter;
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.disk.NodeRecordTask;
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.AsynchronousFileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntFunction;

/**
 * Builds level-0 node records on a pool of worker threads and writes the resulting buffers to
 * {@code filePath} through an {@link AsynchronousFileChannel}.
 *
 * <p>Each worker thread lazily obtains its own {@link ImmutableGraphIndex.View} and its own
 * record-sized {@link ByteBuffer}. Every view handed out is remembered so that {@link #close()}
 * can release all of them.
 *
 * <p>{@link #writeL0Records} is driven from a single calling thread; the class does not attempt
 * to coordinate concurrent callers.
 */
class ParallelGraphWriter
implements AutoCloseable {
    /** Upper bound on the number of threads servicing asynchronous file writes. */
    private static final int MAX_FILE_WRITE_THREADS = 32;
    /** How long, in seconds, pool shutdown waits before falling back to {@code shutdownNow()}. */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 60L;
    /** Source of unique suffixes for worker thread names, shared by all writer instances. */
    private static final AtomicInteger threadCounter = new AtomicInteger(0);

    private final RandomAccessWriter writer;
    private final ImmutableGraphIndex graph;
    private final ExecutorService executor;
    private final ThreadLocal<ImmutableGraphIndex.View> viewPerThread;
    private final ThreadLocal<ByteBuffer> bufferPerThread;
    private final CopyOnWriteArrayList<ImmutableGraphIndex.View> allViews = new CopyOnWriteArrayList<>();
    private final int recordSize;
    private final Path filePath;
    private final int taskMultiplier;

    /**
     * Creates a writer with a fixed pool of {@code config.workerThreads} non-daemon worker threads.
     *
     * @param writer         the underlying writer (retained; not used directly by the methods of this class)
     * @param graph          the graph whose level-0 nodes are serialized
     * @param inlineFeatures features whose {@code featureSize()} contributes to the per-node record size
     * @param config         threading and buffer configuration
     * @param filePath       file that the asynchronous channel opens for writing; must not be null
     * @throws NullPointerException if {@code filePath} is null
     */
    public ParallelGraphWriter(RandomAccessWriter writer, ImmutableGraphIndex graph, List<Feature> inlineFeatures, Config config, Path filePath) {
        this.writer = writer;
        this.graph = graph;
        this.filePath = Objects.requireNonNull(filePath);
        this.taskMultiplier = config.taskMultiplier;
        this.executor = Executors.newFixedThreadPool(config.workerThreads, r -> {
            Thread t = new Thread(r);
            t.setName("ParallelGraphWriter-Worker-" + threadCounter.getAndIncrement());
            t.setDaemon(false);
            return t;
        });
        // Record layout: one int, the inline features, one int, then degree(0) ints
        // (presumably node ordinal, neighbor count and neighbor ids -- confirm against NodeRecordTask).
        int inlineFeatureBytes = inlineFeatures.stream().mapToInt(Feature::featureSize).sum();
        this.recordSize = Integer.BYTES + inlineFeatureBytes + Integer.BYTES + graph.getDegree(0) * Integer.BYTES;
        this.viewPerThread = ThreadLocal.withInitial(() -> {
            ImmutableGraphIndex.View view = graph.getView();
            // Track every view so close() can release views created on worker threads.
            this.allViews.add(view);
            return view;
        });
        int bufferSize = this.recordSize;
        boolean useDirect = config.useDirectBuffers;
        this.bufferPerThread = ThreadLocal.withInitial(() -> {
            ByteBuffer buffer = useDirect ? ByteBuffer.allocateDirect(bufferSize) : ByteBuffer.allocate(bufferSize);
            buffer.order(ByteOrder.BIG_ENDIAN);
            return buffer;
        });
    }

    /**
     * Splits the ordinal range {@code [0, ordinalMapper.maxOrdinal()]} into contiguous chunks,
     * builds the records of each chunk on the worker pool, and writes the results to the file.
     *
     * <p>The number of chunks is {@code availableProcessors * taskMultiplier}, capped at the number
     * of ordinals so that no chunk is empty. Nothing is written when the mapper reports no ordinals.
     *
     * @param ordinalMapper         supplies the maximum ordinal and is passed to each record task
     * @param inlineFeatures        features passed to each record task
     * @param featureStateSuppliers per-feature state suppliers passed to each record task
     * @param baseOffset            offset passed to each record task
     * @throws IOException if building or writing a record fails with an I/O error, or the calling
     *                     thread is interrupted while waiting
     */
    public void writeL0Records(OrdinalMapper ordinalMapper, List<Feature> inlineFeatures, Map<FeatureId, IntFunction<Feature.State>> featureStateSuppliers, long baseOffset) throws IOException {
        int totalOrdinals = ordinalMapper.maxOrdinal() + 1;
        if (totalOrdinals <= 0) {
            // Nothing to write; also avoids dividing by a zero task count below.
            return;
        }
        int numCores = Runtime.getRuntime().availableProcessors();
        // Aim for taskMultiplier tasks per core, but never more tasks than ordinals.
        // The product is computed as a long so an oversized multiplier cannot overflow.
        int numTasks = (int) Math.max(1L, Math.min((long) numCores * this.taskMultiplier, (long) totalOrdinals));
        int ordinalsPerTask = (totalOrdinals + numTasks - 1) / numTasks;

        List<Future<List<NodeRecordTask.Result>>> futures = new ArrayList<>(numTasks);
        for (int i = 0; i < numTasks; ++i) {
            int start = i * ordinalsPerTask;
            if (start >= totalOrdinals) {
                // Rounding up ordinalsPerTask can leave trailing tasks with no work.
                break;
            }
            int end = Math.min(start + ordinalsPerTask, totalOrdinals);
            Future<List<NodeRecordTask.Result>> future = this.executor.submit(() -> {
                ImmutableGraphIndex.View view = this.viewPerThread.get();
                ByteBuffer buffer = this.bufferPerThread.get();
                NodeRecordTask task = new NodeRecordTask(start, end, ordinalMapper, this.graph, view, inlineFeatures, featureStateSuppliers, this.recordSize, baseOffset, buffer);
                return task.call();
            });
            futures.add(future);
        }
        this.writeRecordsAsync(futures);
    }

    /**
     * Waits for each record-building future in submission order and writes every result at its
     * file offset through an asynchronous channel backed by a dedicated daemon thread pool.
     *
     * <p>At most {@code 2 * numThreads} writes are in flight; once that many are pending they are
     * all awaited before more are issued. On exit, successful or not, tasks that have not started
     * are cancelled and the write pool is shut down.
     *
     * @param futures record-building tasks, in the order their results should be consumed
     * @throws IOException if a task or write fails with an I/O error, or the thread is interrupted
     */
    private void writeRecordsAsync(List<Future<List<NodeRecordTask.Result>>> futures) throws IOException {
        EnumSet<StandardOpenOption> opts = EnumSet.of(StandardOpenOption.WRITE, StandardOpenOption.READ);
        int numThreads = Math.min(Runtime.getRuntime().availableProcessors(), MAX_FILE_WRITE_THREADS);
        ExecutorService fileWritePool = new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(), r -> {
            Thread t = new Thread(r, "graphnode-writer");
            t.setDaemon(true);
            return t;
        });
        try {
            int maxConcurrentWrites = numThreads * 2;
            List<PendingWrite> pendingWrites = new ArrayList<>(maxConcurrentWrites);
            try (AsynchronousFileChannel afc = AsynchronousFileChannel.open(this.filePath, opts, fileWritePool)) {
                for (Future<List<NodeRecordTask.Result>> future : futures) {
                    for (NodeRecordTask.Result result : future.get()) {
                        pendingWrites.add(new PendingWrite(afc, result.data, result.fileOffset));
                        if (pendingWrites.size() >= maxConcurrentWrites) {
                            drainWrites(afc, pendingWrites);
                        }
                    }
                }
                drainWrites(afc, pendingWrites);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IOException("Interrupted while writing records", e);
            }
            catch (ExecutionException e) {
                throw this.unwrapExecutionException(e);
            }
        }
        finally {
            // No-op for completed tasks; after a failure this stops queued tasks from starting.
            // Running tasks are not interrupted so their thread-local views stay usable.
            for (Future<List<NodeRecordTask.Result>> future : futures) {
                future.cancel(false);
            }
            shutdownAndAwait(fileWritePool);
        }
    }

    /** Waits for every pending write to finish in full, then empties the list. */
    private static void drainWrites(AsynchronousFileChannel channel, List<PendingWrite> pendingWrites) throws InterruptedException, ExecutionException, IOException {
        for (PendingWrite pending : pendingWrites) {
            pending.awaitCompletion(channel);
        }
        pendingWrites.clear();
    }

    /** Shuts a pool down, forcing termination after the timeout or if the wait is interrupted. */
    private static void shutdownAndAwait(ExecutorService pool) {
        pool.shutdown();
        try {
            if (!pool.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                pool.shutdownNow();
            }
        }
        catch (InterruptedException e) {
            pool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Translates the cause of a failed future: an {@link IOException} is returned for the caller
     * to throw, a {@link RuntimeException} is rethrown as is, and anything else is rethrown wrapped
     * in a {@link RuntimeException}.
     */
    private IOException unwrapExecutionException(ExecutionException e) {
        Throwable cause = e.getCause();
        if (cause instanceof IOException) {
            return (IOException)cause;
        }
        if (cause instanceof RuntimeException) {
            throw (RuntimeException)cause;
        }
        throw new RuntimeException("Error building node record", cause);
    }

    /** Returns the size in bytes of one node record, as computed in the constructor. */
    public int getRecordSize() {
        return this.recordSize;
    }

    /**
     * Shuts down the worker pool and closes every view that was handed to a worker thread.
     *
     * <p>All views are closed even if some fail; the first failure is thrown with later ones
     * attached as suppressed exceptions.
     *
     * @throws IOException if closing a view fails; non-I/O failures are wrapped
     */
    @Override
    public void close() throws IOException {
        shutdownAndAwait(this.executor);
        Exception firstFailure = null;
        for (ImmutableGraphIndex.View view : this.allViews) {
            try {
                view.close();
            }
            catch (Exception e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        this.allViews.clear();
        if (firstFailure instanceof IOException) {
            throw (IOException)firstFailure;
        }
        if (firstFailure != null) {
            throw new IOException("Error closing parallel writer", firstFailure);
        }
    }

    /**
     * One in-flight asynchronous write. Remembers the buffer and target offset so that a write
     * which transferred fewer bytes than requested can be resumed where it stopped.
     */
    private static final class PendingWrite {
        private final ByteBuffer data;
        private final long fileOffset;
        private final int initialPosition;
        private final Future<Integer> firstAttempt;

        /** Starts writing {@code data} at {@code fileOffset} immediately. */
        PendingWrite(AsynchronousFileChannel channel, ByteBuffer data, long fileOffset) {
            this.data = data;
            this.fileOffset = fileOffset;
            this.initialPosition = data.position();
            this.firstAttempt = channel.write(data, fileOffset);
        }

        /** Blocks until the whole buffer has been written, re-issuing writes for any remainder. */
        void awaitCompletion(AsynchronousFileChannel channel) throws InterruptedException, ExecutionException, IOException {
            this.firstAttempt.get();
            // A channel write may transfer fewer bytes than remain; the buffer position records progress.
            while (this.data.hasRemaining()) {
                long resumeOffset = this.fileOffset + (this.data.position() - this.initialPosition);
                int written = channel.write(this.data, resumeOffset).get();
                if (written <= 0) {
                    throw new IOException("Unable to complete write of node record at offset " + resumeOffset);
                }
            }
        }
    }

    /** Tuning knobs for {@link ParallelGraphWriter}. */
    static class Config {
        /** Number of record-building worker threads; always positive. */
        final int workerThreads;
        /** Whether per-thread record buffers are allocated as direct buffers. */
        final boolean useDirectBuffers;
        /** Number of tasks to create per available processor; always positive. */
        final int taskMultiplier;

        /**
         * @param workerThreads    worker thread count; values {@code <= 0} select the number of available processors
         * @param useDirectBuffers whether to allocate direct byte buffers
         * @param taskMultiplier   tasks per processor; values {@code <= 0} select 4
         */
        public Config(int workerThreads, boolean useDirectBuffers, int taskMultiplier) {
            this.workerThreads = workerThreads <= 0 ? Runtime.getRuntime().availableProcessors() : workerThreads;
            this.useDirectBuffers = useDirectBuffers;
            this.taskMultiplier = taskMultiplier <= 0 ? 4 : taskMultiplier;
        }

        /** Returns a configuration with one worker per processor, heap buffers and a task multiplier of 4. */
        public static Config defaultConfig() {
            return new Config(0, false, 4);
        }
    }
}

