/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.memory;

import com.google.common.collect.ImmutableList;
import io.trino.plugin.memory.MemoryConfig;
import io.trino.plugin.memory.MemoryErrorCode;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

@ThreadSafe
public class MemoryPagesStore {
    private final long maxBytes;
    @GuardedBy(value="this")
    private long currentBytes;
    private final Map<Long, TableData> tables = new HashMap<Long, TableData>();

    @Inject
    public MemoryPagesStore(MemoryConfig config) {
        this.maxBytes = config.getMaxDataPerNode().toBytes();
    }

    public synchronized void initialize(long tableId) {
        if (!this.tables.containsKey(tableId)) {
            this.tables.put(tableId, new TableData());
        }
    }

    public synchronized void add(Long tableId, Page page) {
        if (!this.contains(tableId)) {
            throw new TrinoException((ErrorCodeSupplier)MemoryErrorCode.MISSING_DATA, "Failed to find table on a worker.");
        }
        page.compact();
        long newSize = this.currentBytes + page.getRetainedSizeInBytes();
        if (this.maxBytes < newSize) {
            throw new TrinoException((ErrorCodeSupplier)MemoryErrorCode.MEMORY_LIMIT_EXCEEDED, String.format("Memory limit [%d] for memory connector exceeded", this.maxBytes));
        }
        this.currentBytes = newSize;
        TableData tableData = this.tables.get(tableId);
        tableData.add(page);
    }

    public synchronized List<Page> getPages(Long tableId, int partNumber, int totalParts, List<Integer> columnIndexes, long expectedRows, OptionalLong limit, OptionalDouble sampleRatio) {
        if (!this.contains(tableId)) {
            throw new TrinoException((ErrorCodeSupplier)MemoryErrorCode.MISSING_DATA, "Failed to find table on a worker.");
        }
        TableData tableData = this.tables.get(tableId);
        if (tableData.getRows() < expectedRows) {
            throw new TrinoException((ErrorCodeSupplier)MemoryErrorCode.MISSING_DATA, String.format("Expected to find [%s] rows on a worker, but found [%s].", expectedRows, tableData.getRows()));
        }
        ImmutableList.Builder partitionedPages = ImmutableList.builder();
        boolean done = false;
        long totalRows = 0L;
        for (int i = partNumber; i < tableData.getPages().size() && !done; i += totalParts) {
            if (sampleRatio.isPresent() && ThreadLocalRandom.current().nextDouble() >= sampleRatio.getAsDouble()) continue;
            Page page = tableData.getPages().get(i);
            if (limit.isPresent() && (totalRows += (long)page.getPositionCount()) > limit.getAsLong()) {
                page = page.getRegion(0, (int)((long)page.getPositionCount() - (totalRows - limit.getAsLong())));
                done = true;
            }
            partitionedPages.add((Object)MemoryPagesStore.getColumns(page, columnIndexes));
        }
        return partitionedPages.build();
    }

    public synchronized boolean contains(Long tableId) {
        return this.tables.containsKey(tableId);
    }

    public synchronized void cleanUp(Set<Long> activeTableIds) {
        if (activeTableIds.isEmpty()) {
            return;
        }
        long latestTableId = Collections.max(activeTableIds);
        Iterator<Map.Entry<Long, TableData>> tableDataIterator = this.tables.entrySet().iterator();
        while (tableDataIterator.hasNext()) {
            Map.Entry<Long, TableData> tablePagesEntry = tableDataIterator.next();
            Long tableId = tablePagesEntry.getKey();
            if (tableId >= latestTableId || activeTableIds.contains(tableId)) continue;
            for (Page removedPage : tablePagesEntry.getValue().getPages()) {
                this.currentBytes -= removedPage.getRetainedSizeInBytes();
            }
            tableDataIterator.remove();
        }
    }

    private static Page getColumns(Page page, List<Integer> columnIndexes) {
        Block[] outputBlocks = new Block[columnIndexes.size()];
        for (int i = 0; i < columnIndexes.size(); ++i) {
            outputBlocks[i] = page.getBlock(columnIndexes.get(i).intValue());
        }
        return new Page(page.getPositionCount(), outputBlocks);
    }

    private static final class TableData {
        private final List<Page> pages = new ArrayList<Page>();
        private long rows;

        private TableData() {
        }

        public void add(Page page) {
            this.pages.add(page);
            this.rows += (long)page.getPositionCount();
        }

        private List<Page> getPages() {
            return this.pages;
        }

        private long getRows() {
            return this.rows;
        }
    }
}

