/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import com.google.common.math.Quantiles;
import com.google.common.math.Stats;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.annotation.NotThreadSafe;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.faulttolerant.TaskDescriptor;
import io.trino.metadata.Split;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/**
 * In-memory store of {@link TaskDescriptor}s, keyed by query, stage and partition id.
 * <p>
 * The retained size of every stored descriptor is accounted against a single global limit
 * ({@code maxMemoryInBytes}). When an insertion pushes the total reservation over that limit, the
 * storage of the query currently holding the most bytes is failed: its descriptors are dropped and
 * every later {@code put}/{@code get}/{@code remove} for that query rethrows a {@link TrinoException}
 * with {@code EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY}. This repeats until the total fits again.
 * <p>
 * All state is guarded by this instance's monitor. JMX statistics are recomputed at most once per second.
 */
public class TaskDescriptorStorage {
    private static final Logger log = Logger.get(TaskDescriptorStorage.class);

    private final long maxMemoryInBytes;
    private final JsonCodec<Split> splitJsonCodec;
    private final StorageStats storageStats;

    @GuardedBy("this")
    private final Map<QueryId, TaskDescriptors> storages = new HashMap<>();
    // Running sum of the reserved bytes of every entry in 'storages'.
    @GuardedBy("this")
    private long reservedBytes;

    @Inject
    public TaskDescriptorStorage(QueryManagerConfig config, JsonCodec<Split> splitJsonCodec) {
        this(config.getFaultTolerantExecutionTaskDescriptorStorageMaxMemory(), splitJsonCodec);
    }

    public TaskDescriptorStorage(DataSize maxMemory, JsonCodec<Split> splitJsonCodec) {
        this.maxMemoryInBytes = maxMemory.toBytes();
        this.splitJsonCodec = Objects.requireNonNull(splitJsonCodec, "splitJsonCodec is null");
        // computeStats iterates over all per-query storages while holding the lock; cache its result for a second.
        this.storageStats = new StorageStats(Suppliers.memoizeWithExpiration(this::computeStats, 1, TimeUnit.SECONDS));
    }

    /**
     * Creates an empty descriptor storage for {@code queryId}.
     *
     * @throws VerifyException if storage has already been initialized for the query
     */
    public synchronized void initialize(QueryId queryId) {
        TaskDescriptors storage = new TaskDescriptors();
        Verify.verify(storages.putIfAbsent(queryId, storage) == null, "storage is already initialized for query: %s", queryId);
        updateMemoryReservation(storage.getReservedBytes());
    }

    /**
     * Stores {@code descriptor} under {@code stageId} and the descriptor's partition id.
     * Silently ignored when the query was never initialized or has already been destroyed.
     * Growing the reservation may fail the storage of one or more queries (possibly this one).
     *
     * @throws IllegalStateException if a descriptor is already stored for the same stage and partition
     * @throws RuntimeException the recorded failure, if this query's storage has already been failed
     */
    public synchronized void put(StageId stageId, TaskDescriptor descriptor) {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage == null) {
            return;
        }
        long reservedBefore = storage.getReservedBytes();
        storage.put(stageId, descriptor.getPartitionId(), descriptor);
        updateMemoryReservation(storage.getReservedBytes() - reservedBefore);
    }

    /**
     * Looks up the descriptor for a stage partition.
     *
     * @return empty when the query was never initialized or has been destroyed
     * @throws NoSuchElementException if the query is known but no descriptor is stored for the key
     * @throws RuntimeException the recorded failure, if this query's storage has been failed
     */
    public synchronized Optional<TaskDescriptor> get(StageId stageId, int partitionId) {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage == null) {
            return Optional.empty();
        }
        return Optional.of(storage.get(stageId, partitionId));
    }

    /**
     * Removes the descriptor for a stage partition and releases its bytes.
     * Silently ignored when the query was never initialized or has been destroyed.
     *
     * @throws NoSuchElementException if the query is known but no descriptor is stored for the key
     * @throws RuntimeException the recorded failure, if this query's storage has been failed
     */
    public synchronized void remove(StageId stageId, int partitionId) {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage == null) {
            return;
        }
        long reservedBefore = storage.getReservedBytes();
        storage.remove(stageId, partitionId);
        updateMemoryReservation(storage.getReservedBytes() - reservedBefore);
    }

    /**
     * Drops all descriptors of {@code queryId} and releases their bytes. No-op for unknown queries.
     */
    public synchronized void destroy(QueryId queryId) {
        TaskDescriptors storage = storages.remove(queryId);
        if (storage != null) {
            updateMemoryReservation(-storage.getReservedBytes());
        }
    }

    /**
     * Applies {@code delta} to the global reservation. If the reservation grew and now exceeds
     * {@code maxMemoryInBytes}, repeatedly fails the storage of the query with the largest reservation
     * (which drops its descriptors) until the total is back within the limit.
     */
    private synchronized void updateMemoryReservation(long delta) {
        reservedBytes += delta;
        if (delta <= 0) {
            // A shrinking or unchanged reservation cannot newly exceed the limit.
            return;
        }
        while (reservedBytes > maxMemoryInBytes) {
            QueryId killCandidate = storages.entrySet().stream()
                    .max(Comparator.comparingLong(entry -> entry.getValue().getReservedBytes()))
                    .map(Map.Entry::getKey)
                    .orElseThrow(() -> new VerifyException(String.format(
                            "storage is empty but reservedBytes (%s) is still greater than maxMemoryInBytes (%s)",
                            reservedBytes,
                            maxMemoryInBytes)));
            TaskDescriptors storage = storages.get(killCandidate);
            long reservedBefore = storage.getReservedBytes();
            if (log.isInfoEnabled()) {
                log.info(
                        "Failing query %s; reclaiming %s of %s task descriptor memory from %s queries; extraStorageInfo=%s",
                        killCandidate,
                        DataSize.succinctBytes(storage.getReservedBytes()),
                        DataSize.succinctBytes(reservedBytes),
                        storages.size(),
                        storage.getDebugInfo());
            }
            // The reservation is what exceeds the configured maximum, so it goes on the left-hand side.
            storage.fail(new TrinoException(
                    StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY,
                    String.format(
                            "Task descriptor storage capacity has been exceeded: %s > %s",
                            DataSize.succinctBytes(reservedBytes),
                            DataSize.succinctBytes(maxMemoryInBytes))));
            reservedBytes += storage.getReservedBytes() - reservedBefore;
        }
    }

    @VisibleForTesting
    synchronized long getReservedBytes() {
        return reservedBytes;
    }

    @Managed
    @Nested
    public StorageStats getStats() {
        return storageStats;
    }

    /**
     * Takes a snapshot of storage statistics. Percentiles and averages are reported as 0 when there is no data.
     */
    private synchronized StorageStatsValue computeStats() {
        int queriesCount = storages.size();
        long stagesCount = storages.values().stream().mapToLong(TaskDescriptors::getStagesCount).sum();

        Quantiles.ScaleAndIndexes percentiles = Quantiles.percentiles().indexes(50, 90, 95);

        long queryReservedBytesP50 = 0;
        long queryReservedBytesP90 = 0;
        long queryReservedBytesP95 = 0;
        long queryReservedBytesAvg = 0;
        long stageReservedBytesP50 = 0;
        long stageReservedBytesP90 = 0;
        long stageReservedBytesP95 = 0;
        long stageReservedBytesAvg = 0;

        // Quantiles cannot be computed over an empty data set.
        if (queriesCount > 0) {
            List<Long> queriesReservedBytes = storages.values().stream()
                    .map(TaskDescriptors::getReservedBytes)
                    .collect(ImmutableList.toImmutableList());
            Map<Integer, Double> queryReservedBytesPercentiles = percentiles.compute(queriesReservedBytes);
            queryReservedBytesP50 = queryReservedBytesPercentiles.get(50).longValue();
            queryReservedBytesP90 = queryReservedBytesPercentiles.get(90).longValue();
            queryReservedBytesP95 = queryReservedBytesPercentiles.get(95).longValue();
            queryReservedBytesAvg = reservedBytes / queriesCount;

            List<Long> stagesReservedBytes = storages.values().stream()
                    .flatMap(TaskDescriptors::getStagesReservedBytes)
                    .collect(ImmutableList.toImmutableList());
            if (!stagesReservedBytes.isEmpty()) {
                Map<Integer, Double> stageReservedBytesPercentiles = percentiles.compute(stagesReservedBytes);
                stageReservedBytesP50 = stageReservedBytesPercentiles.get(50).longValue();
                stageReservedBytesP90 = stageReservedBytesPercentiles.get(90).longValue();
                stageReservedBytesP95 = stageReservedBytesPercentiles.get(95).longValue();
                // Per-stage counters outlive their descriptors once every partition of a stage is removed,
                // so stagesCount (derived from live descriptors) may be 0 here; avoid dividing by zero.
                stageReservedBytesAvg = stagesCount > 0 ? reservedBytes / stagesCount : 0;
            }
        }

        return new StorageStatsValue(
                queriesCount,
                stagesCount,
                reservedBytes,
                queryReservedBytesAvg,
                queryReservedBytesP50,
                queryReservedBytesP90,
                queryReservedBytesP95,
                stageReservedBytesAvg,
                stageReservedBytesP50,
                stageReservedBytesP90,
                stageReservedBytesP95);
    }

    /**
     * JMX view of the storage statistics; every getter reads from the (cached) statistics supplier.
     */
    public static class StorageStats {
        private final Supplier<StorageStatsValue> statsSupplier;

        StorageStats(Supplier<StorageStatsValue> statsSupplier) {
            this.statsSupplier = Objects.requireNonNull(statsSupplier, "statsSupplier is null");
        }

        @Managed
        public long getQueriesCount() {
            return statsSupplier.get().queriesCount();
        }

        @Managed
        public long getStagesCount() {
            return statsSupplier.get().stagesCount();
        }

        @Managed
        public long getReservedBytes() {
            return statsSupplier.get().reservedBytes();
        }

        @Managed
        public long getQueryReservedBytesAvg() {
            return statsSupplier.get().queryReservedBytesAvg();
        }

        @Managed
        public long getQueryReservedBytesP50() {
            return statsSupplier.get().queryReservedBytesP50();
        }

        @Managed
        public long getQueryReservedBytesP90() {
            return statsSupplier.get().queryReservedBytesP90();
        }

        @Managed
        public long getQueryReservedBytesP95() {
            return statsSupplier.get().queryReservedBytesP95();
        }

        @Managed
        public long getStageReservedBytesAvg() {
            return statsSupplier.get().stageReservedBytesAvg();
        }

        @Managed
        public long getStageReservedBytesP50() {
            return statsSupplier.get().stageReservedBytesP50();
        }

        @Managed
        public long getStageReservedBytesP90() {
            return statsSupplier.get().stageReservedBytesP90();
        }

        @Managed
        public long getStageReservedBytesP95() {
            return statsSupplier.get().stageReservedBytesP95();
        }
    }

    /**
     * Descriptors of a single query keyed by (stage, partition), with per-query and per-stage byte accounting.
     * Not thread safe on its own; in this file it is only touched from synchronized methods of the enclosing storage.
     * Once {@link #fail} has been called, all descriptors are dropped and every accessor rethrows the failure.
     */
    @NotThreadSafe
    private class TaskDescriptors {
        private final Table<StageId, Integer, TaskDescriptor> descriptors = HashBasedTable.create();
        private long reservedBytes;
        private final Map<StageId, AtomicLong> stagesReservedBytes = new HashMap<>();
        private RuntimeException failure;

        private TaskDescriptors() {
        }

        /**
         * @throws IllegalStateException if a descriptor is already stored for the key
         */
        public void put(StageId stageId, int partitionId, TaskDescriptor descriptor) {
            throwIfFailed();
            Preconditions.checkState(!descriptors.contains(stageId, partitionId), "task descriptor is already present for key %s/%s ", stageId, partitionId);
            descriptors.put(stageId, partitionId, descriptor);
            long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes();
            reservedBytes += descriptorRetainedBytes;
            stagesReservedBytes.computeIfAbsent(stageId, ignored -> new AtomicLong()).addAndGet(descriptorRetainedBytes);
        }

        /**
         * @throws NoSuchElementException if no descriptor is stored for the key
         */
        public TaskDescriptor get(StageId stageId, int partitionId) {
            throwIfFailed();
            TaskDescriptor descriptor = descriptors.get(stageId, partitionId);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s/%s", stageId, partitionId));
            }
            return descriptor;
        }

        /**
         * @throws NoSuchElementException if no descriptor is stored for the key
         */
        public void remove(StageId stageId, int partitionId) {
            throwIfFailed();
            TaskDescriptor descriptor = descriptors.remove(stageId, partitionId);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s/%s", stageId, partitionId));
            }
            long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes();
            reservedBytes -= descriptorRetainedBytes;
            Objects.requireNonNull(stagesReservedBytes.get(stageId), () -> String.format("no entry for stage %s", stageId))
                    .addAndGet(-descriptorRetainedBytes);
        }

        public long getReservedBytes() {
            return reservedBytes;
        }

        /**
         * Builds a diagnostic summary: per-stage descriptor/split statistics plus the three largest splits
         * (serialized to JSON) across all stages of this query.
         */
        private String getDebugInfo() {
            Multimap<StageId, TaskDescriptor> descriptorsByStageId = descriptors.cellSet().stream()
                    .collect(ImmutableSetMultimap.toImmutableSetMultimap(Table.Cell::getRowKey, Table.Cell::getValue));

            Map<StageId, String> debugInfoByStageId = descriptorsByStageId.asMap().entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, stageEntry -> getDebugInfo(stageEntry.getValue())));

            List<String> biggestSplits = descriptorsByStageId.entries().stream()
                    .flatMap(stageEntry -> stageEntry.getValue().getSplits().getSplitsFlat().entries().stream()
                            .map(splitEntry -> Map.entry("%s/%s".formatted(stageEntry.getKey(), splitEntry.getKey()), splitEntry.getValue())))
                    .sorted(Comparator.<Map.Entry<String, Split>>comparingLong(split -> split.getValue().getRetainedSizeInBytes()).reversed())
                    .limit(3)
                    .map(split -> "{nodeId=%s, size=%s, split=%s}".formatted(
                            split.getKey(),
                            split.getValue().getRetainedSizeInBytes(),
                            splitJsonCodec.toJson(split.getValue())))
                    .toList();

            return "stagesInfo=%s; biggestSplits=%s".formatted(debugInfoByStageId, biggestSplits);
        }

        /**
         * Summarizes a non-empty group of descriptors (one stage): descriptor count, retained-size mean/stddev,
         * and per-plan-node split count and split size mean/stddev.
         */
        private String getDebugInfo(Collection<TaskDescriptor> taskDescriptors) {
            int taskDescriptorsCount = taskDescriptors.size();
            Stats taskDescriptorsRetainedSizeStats = Stats.of(taskDescriptors.stream().mapToLong(TaskDescriptor::getRetainedSizeInBytes));

            Set<PlanNodeId> planNodeIds = taskDescriptors.stream()
                    .flatMap(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().keySet().stream())
                    .collect(ImmutableSet.toImmutableSet());
            Map<PlanNodeId, String> splitsDebugInfo = new HashMap<>();
            for (PlanNodeId planNodeId : planNodeIds) {
                // Multimap.get yields an empty collection (never null) for descriptors without splits for this
                // plan node; asMap().get would return null and throw while failing a query.
                Stats splitCountStats = Stats.of(taskDescriptors.stream()
                        .mapToLong(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().get(planNodeId).size()));
                Stats splitSizeStats = Stats.of(taskDescriptors.stream()
                        .flatMap(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().get(planNodeId).stream())
                        .mapToLong(Split::getRetainedSizeInBytes));
                splitsDebugInfo.put(planNodeId, "{splitCountMean=%s, splitCountStdDev=%s, splitSizeMean=%s, splitSizeStdDev=%s}".formatted(
                        splitCountStats.mean(),
                        splitCountStats.populationStandardDeviation(),
                        splitSizeStats.mean(),
                        splitSizeStats.populationStandardDeviation()));
            }
            return "[taskDescriptorsCount=%s, taskDescriptorsRetainedSizeMean=%s, taskDescriptorsRetainedSizeStdDev=%s, splits=%s]".formatted(
                    taskDescriptorsCount,
                    taskDescriptorsRetainedSizeStats.mean(),
                    taskDescriptorsRetainedSizeStats.populationStandardDeviation(),
                    splitsDebugInfo);
        }

        /**
         * Records the first failure and releases all descriptors; later calls keep the original failure.
         */
        private void fail(RuntimeException failure) {
            if (this.failure == null) {
                descriptors.clear();
                // Drop per-stage counters too, otherwise stage statistics keep reporting released memory.
                stagesReservedBytes.clear();
                reservedBytes = 0;
                this.failure = failure;
            }
        }

        private void throwIfFailed() {
            if (failure != null) {
                throw failure;
            }
        }

        /**
         * Number of stages that currently have at least one stored descriptor.
         */
        public int getStagesCount() {
            return descriptors.rowMap().size();
        }

        /**
         * Reserved bytes per stage that has ever had a descriptor stored (entries may be 0).
         */
        public Stream<Long> getStagesReservedBytes() {
            return stagesReservedBytes.values().stream().map(AtomicLong::get);
        }
    }

    private record StorageStatsValue(long queriesCount, long stagesCount, long reservedBytes, long queryReservedBytesAvg, long queryReservedBytesP50, long queryReservedBytesP90, long queryReservedBytesP95, long stageReservedBytesAvg, long stageReservedBytesP50, long stageReservedBytesP90, long stageReservedBytesP95) {
    }
}

