/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.execution;

import com.facebook.airlift.concurrent.Threads;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.execution.MemoryRevokingSchedulerUtils;
import com.facebook.presto.execution.MemoryRevokingUtils;
import com.facebook.presto.execution.SqlTask;
import com.facebook.presto.execution.SqlTaskManager;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskState;
import com.facebook.presto.memory.LocalMemoryManager;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.MemoryPoolListener;
import com.facebook.presto.memory.QueryContext;
import com.facebook.presto.memory.QueryContextVisitor;
import com.facebook.presto.memory.VoidTraversingQueryContextVisitor;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.PipelineContext;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import jakarta.inject.Inject;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;

/**
 * Requests revocation of revocable operator memory when memory pools, or individual queries, approach their limits.
 * <p>
 * A {@link MemoryPoolListener} is registered on every supplied {@link MemoryPool}. Each reservation notification may
 * schedule work on a single-threaded {@code memory-revocation} executor:
 * <ul>
 * <li>pool level: when the pool holds revocable memory and its free bytes have dropped to
 * {@code maxBytes * (1 - memoryRevokingThreshold)} or below, running tasks in that pool are asked to revoke memory
 * until the pool's free bytes would reach {@code maxBytes * (1 - memoryRevokingTarget)};</li>
 * <li>query level (only when query-limit spilling is enabled): when a query's reservation reaches its maximum total
 * memory, the query's tasks are asked to revoke the excess, largest revocable reservation first.</li>
 * </ul>
 */
public class MemoryRevokingScheduler {
    private static final Logger log = Logger.get(MemoryRevokingScheduler.class);
    private static final Ordering<SqlTask> ORDER_BY_CREATE_TIME = Ordering.natural().onResultOf(SqlTask::getTaskCreatedTime);

    private final Function<QueryId, QueryContext> queryContextSupplier;
    private final Supplier<List<SqlTask>> currentTasksSupplier;
    private final ExecutorService memoryRevocationExecutor;
    private final double memoryRevokingThreshold;
    private final double memoryRevokingTarget;
    private final FeaturesConfig.TaskSpillingStrategy spillingStrategy;
    private final List<MemoryPool> memoryPools;
    private final MemoryPoolListener memoryPoolListener = this::onMemoryReserved;
    private final boolean queryLimitSpillEnabled;

    /**
     * Injection constructor: watches the pools of {@code localMemoryManager} and obtains tasks and query contexts
     * from {@code sqlTaskManager}.
     *
     * @throws NullPointerException if {@code sqlTaskManager} is null (reported here, not on the first callback)
     */
    @Inject
    public MemoryRevokingScheduler(LocalMemoryManager localMemoryManager, SqlTaskManager sqlTaskManager, FeaturesConfig config) {
        // Bound method references evaluate (and null-check) the receiver immediately, unlike the lambdas the
        // decompiler produced, which deferred the check until the suppliers were first invoked.
        this(
                ImmutableList.copyOf(MemoryRevokingUtils.getMemoryPools(localMemoryManager)),
                Objects.requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getAllTasks,
                Objects.requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getQueryContext,
                config.getMemoryRevokingThreshold(),
                config.getMemoryRevokingTarget(),
                config.getTaskSpillingStrategy(),
                config.isQueryLimitSpillEnabled());
    }

    /**
     * Creates a scheduler over explicit pools and suppliers.
     *
     * @param memoryPools pools whose reservations are monitored
     * @param currentTasksSupplier supplies all current tasks, used for pool-level revocation
     * @param queryContextSupplier resolves a query's context, used for query-level revocation
     * @param memoryRevokingThreshold fraction in [0, 1]; revocation starts once used memory reaches this fraction
     * @param memoryRevokingTarget fraction in [0, 1], at most {@code memoryRevokingThreshold}; revocation aims to bring
     * usage back down to this fraction
     * @param taskSpillingStrategy order in which tasks are asked to revoke; must not be PER_TASK_MEMORY_THRESHOLD
     * @param queryLimitSpillEnabled whether to also revoke when a single query reaches its own memory limit
     * @throws IllegalArgumentException on out-of-range fractions, target above threshold, or an unsupported strategy
     */
    @VisibleForTesting
    MemoryRevokingScheduler(List<MemoryPool> memoryPools, Supplier<List<SqlTask>> currentTasksSupplier, Function<QueryId, QueryContext> queryContextSupplier, double memoryRevokingThreshold, double memoryRevokingTarget, FeaturesConfig.TaskSpillingStrategy taskSpillingStrategy, boolean queryLimitSpillEnabled) {
        this.memoryPools = ImmutableList.copyOf(Objects.requireNonNull(memoryPools, "memoryPools is null"));
        this.currentTasksSupplier = Objects.requireNonNull(currentTasksSupplier, "allTasksSupplier is null");
        this.queryContextSupplier = Objects.requireNonNull(queryContextSupplier, "queryContextSupplier is null");
        this.memoryRevokingThreshold = checkFraction(memoryRevokingThreshold, "memoryRevokingThreshold");
        this.memoryRevokingTarget = checkFraction(memoryRevokingTarget, "memoryRevokingTarget");
        this.memoryRevocationExecutor = Executors.newSingleThreadExecutor(Threads.threadsNamed("memory-revocation"));
        this.spillingStrategy = Objects.requireNonNull(taskSpillingStrategy, "taskSpillingStrategy is null");
        Preconditions.checkArgument(
                this.spillingStrategy != FeaturesConfig.TaskSpillingStrategy.PER_TASK_MEMORY_THRESHOLD,
                "spilling strategy cannot be PER_TASK_MEMORY_THRESHOLD in MemoryRevokingScheduler");
        Preconditions.checkArgument(
                memoryRevokingTarget <= memoryRevokingThreshold,
                "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively",
                memoryRevokingTarget,
                memoryRevokingThreshold);
        this.queryLimitSpillEnabled = queryLimitSpillEnabled;
    }

    /** Returns {@code value} unchanged after verifying it lies in the closed range [0, 1]. */
    private static double checkFraction(double value, String valueName) {
        Objects.requireNonNull(valueName, "valueName is null");
        Preconditions.checkArgument(0.0 <= value && value <= 1.0, "%s should be within [0, 1] range, got %s", valueName, value);
        return value;
    }

    /** Starts listening for reservations on all monitored pools. */
    @PostConstruct
    public void start() {
        registerPoolListeners();
    }

    /** Unregisters the pool listeners and shuts down the revocation executor. */
    @PreDestroy
    public void stop() {
        memoryPools.forEach(memoryPool -> memoryPool.removeListener(memoryPoolListener));
        memoryRevocationExecutor.shutdown();
    }

    private void registerPoolListeners() {
        memoryPools.forEach(memoryPool -> memoryPool.addListener(memoryPoolListener));
    }

    /** Blocks until every task submitted to the revocation executor before this call has run. */
    @VisibleForTesting
    void awaitAsynchronousCallbacksRun() throws InterruptedException {
        // The executor is single-threaded, so a no-op task completes only after all earlier tasks.
        memoryRevocationExecutor.invokeAll(Collections.singletonList(() -> null));
    }

    @VisibleForTesting
    void submitAsynchronousCallable(Callable<?> callable) {
        memoryRevocationExecutor.submit(callable);
    }

    /**
     * Pool listener callback. Schedules query-level and/or pool-level revocation when their triggers are met.
     * Exceptions are logged rather than propagated back into the memory pool.
     */
    private void onMemoryReserved(MemoryPool memoryPool, QueryId queryId, long queryMemoryReservation) {
        try {
            if (queryLimitSpillEnabled) {
                QueryContext queryContext = queryContextSupplier.apply(queryId);
                Verify.verify(queryContext != null, "QueryContext not found for queryId %s", queryId);
                long maxTotalMemory = queryContext.getMaxTotalMemory();
                if (memoryRevokingNeededForQuery(queryMemoryReservation, maxTotalMemory)) {
                    log.debug("Scheduling check for %s", queryId);
                    scheduleQueryRevoking(queryContext, maxTotalMemory);
                }
            }
            if (memoryRevokingNeededForPool(memoryPool)) {
                log.debug("Scheduling check for %s", memoryPool);
                scheduleMemoryPoolRevoking(memoryPool);
            }
        }
        catch (Exception e) {
            log.error(e, "Error when acting on memory pool reservation");
        }
    }

    private boolean memoryRevokingNeededForQuery(long queryMemoryReservation, long maxTotalMemory) {
        return queryMemoryReservation >= maxTotalMemory;
    }

    private void scheduleQueryRevoking(QueryContext queryContext, long maxTotalMemory) {
        memoryRevocationExecutor.execute(() -> {
            try {
                revokeQueryMemory(queryContext, maxTotalMemory);
            }
            catch (Exception e) {
                log.error(e, "Error requesting memory revoking");
            }
        });
    }

    /**
     * Asks the query's tasks to revoke enough memory to bring the query back down to {@code maxTotalMemory}.
     * <p>
     * Tasks are visited in decreasing order of revocable memory, so that as few tasks as possible are asked to spill.
     * Memory already being revoked counts toward the goal.
     */
    private void revokeQueryMemory(QueryContext queryContext, long maxTotalMemory) {
        QueryId queryId = queryContext.getQueryId();
        MemoryPool memoryPool = queryContext.getMemoryPool();
        long queryTotalMemory = getTotalQueryMemoryReservation(queryId, memoryPool);

        // Keep every task in a list. A map keyed by revocable bytes would silently drop tasks with equal
        // reservations. Each task's revocable bytes are snapshotted once, because the live value can change
        // concurrently and a sort comparator must see stable keys.
        List<TaskContext> queryTaskContexts = new ArrayList<>();
        HashMap<TaskId, Long> revocableMemoryByTask = new HashMap<>();
        queryContext.getAllTaskContexts().forEach(context -> {
            queryTaskContexts.add(context);
            revocableMemoryByTask.put(context.getTaskId(), context.getTaskMemoryContext().getRevocableMemory());
        });
        queryTaskContexts.sort(Comparator.comparingLong((TaskContext context) -> revocableMemoryByTask.getOrDefault(context.getTaskId(), 0L)).reversed());

        AtomicLong remainingBytesToRevoke = new AtomicLong(queryTotalMemory - maxTotalMemory);
        remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(queryTaskContexts, remainingBytesToRevoke.get()));
        for (TaskContext taskContext : queryTaskContexts) {
            if (remainingBytesToRevoke.get() <= 0) {
                break;
            }
            taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>() {
                @Override
                public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remaining) {
                    if (remaining.get() > 0) {
                        long revokedBytes = operatorContext.requestMemoryRevoking();
                        if (revokedBytes > 0) {
                            remaining.addAndGet(-revokedBytes);
                            log.debug("taskId=%s: requested revoking %s; remaining %s", taskContext.getTaskId(), revokedBytes, remaining.get());
                        }
                    }
                    return null;
                }
            }, remainingBytesToRevoke);
        }
    }

    /** Sums the query's non-revocable and revocable reservations in {@code memoryPool}. */
    private static long getTotalQueryMemoryReservation(QueryId queryId, MemoryPool memoryPool) {
        return memoryPool.getQueryMemoryReservation(queryId) + memoryPool.getQueryRevocableMemoryReservation(queryId);
    }

    private void scheduleMemoryPoolRevoking(MemoryPool memoryPool) {
        memoryRevocationExecutor.execute(() -> {
            try {
                runMemoryPoolRevoking(memoryPool);
            }
            catch (Exception e) {
                log.error(e, "Error requesting memory revoking");
            }
        });
    }

    /**
     * Re-checks the pool trigger, since the state may have changed since scheduling. If revocation is still
     * needed, requests it from the pool's running tasks.
     */
    @VisibleForTesting
    void runMemoryPoolRevoking(MemoryPool memoryPool) {
        if (!memoryRevokingNeededForPool(memoryPool)) {
            return;
        }
        Collection<SqlTask> allTasks = Objects.requireNonNull(currentTasksSupplier.get());
        requestMemoryPoolRevoking(memoryPool, allTasks);
    }

    /**
     * Computes the bytes needed to raise the pool's free bytes to {@code maxBytes * (1 - memoryRevokingTarget)},
     * less what is already being revoked, and requests that amount from the pool's running tasks.
     */
    private void requestMemoryPoolRevoking(MemoryPool memoryPool, Collection<SqlTask> allTasks) {
        long remainingBytesToRevoke = (long) (-memoryPool.getFreeBytes() + memoryPool.getMaxBytes() * (1.0 - memoryRevokingTarget));
        List<SqlTask> runningTasksInPool = findRunningTasksInMemoryPool(allTasks, memoryPool);
        remainingBytesToRevoke -= getMemoryAlreadyBeingRevoked(runningTasksInPool, remainingBytesToRevoke);
        if (remainingBytesToRevoke > 0) {
            requestRevoking(memoryPool.getId(), runningTasksInPool, remainingBytesToRevoke);
        }
    }

    /** True when the pool holds revocable memory and its free bytes are at or below the threshold headroom. */
    private boolean memoryRevokingNeededForPool(MemoryPool memoryPool) {
        return memoryPool.getReservedRevocableBytes() > 0
                && memoryPool.getFreeBytes() <= memoryPool.getMaxBytes() * (1.0 - memoryRevokingThreshold);
    }

    /** Delegates to the shared utility, using only the tasks that currently have a task context. */
    private long getMemoryAlreadyBeingRevoked(List<SqlTask> sqlTasks, long targetRevokingLimit) {
        List<TaskContext> taskContexts = sqlTasks.stream()
                .map(SqlTask::getTaskContext)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(ImmutableList.toImmutableList());
        return MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(taskContexts, targetRevokingLimit);
    }

    /**
     * Visits the tasks, in the configured spilling order, and asks their operators to revoke memory until
     * {@code remainingBytesToRevoke} is covered. {@code sqlTasks} is sorted in place.
     */
    private void requestRevoking(MemoryPoolId memoryPoolId, List<SqlTask> sqlTasks, long remainingBytesToRevoke) {
        VoidTraversingQueryContextVisitor<AtomicLong> visitor = new VoidTraversingQueryContextVisitor<AtomicLong>() {
            @Override
            public Void visitPipelineContext(PipelineContext pipelineContext, AtomicLong remaining) {
                // Stop descending into further pipelines once enough has been requested.
                if (remaining.get() <= 0) {
                    return null;
                }
                return super.visitPipelineContext(pipelineContext, remaining);
            }

            @Override
            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remaining) {
                if (remaining.get() > 0) {
                    long revokedBytes = operatorContext.requestMemoryRevoking();
                    if (revokedBytes > 0) {
                        remaining.addAndGet(-revokedBytes);
                        log.debug("memoryPool=%s, operatorContext: %s: requested revoking %s; remaining %s", memoryPoolId, operatorContext, revokedBytes, remaining.get());
                    }
                }
                return null;
            }
        };

        log.debug("Ordering by %s", spillingStrategy);
        sortTasksToTraversalOrder(sqlTasks, spillingStrategy);

        AtomicLong remainingBytesToRevokeAtomic = new AtomicLong(remainingBytesToRevoke);
        for (SqlTask task : sqlTasks) {
            Optional<TaskContext> taskContext = task.getTaskContext();
            if (!taskContext.isPresent()) {
                continue;
            }
            taskContext.get().accept(visitor, remainingBytesToRevokeAtomic);
            if (remainingBytesToRevokeAtomic.get() <= 0) {
                return;
            }
        }
    }

    /** Sorts {@code sqlTasks} in place into the order in which they should be asked to revoke memory. */
    private static void sortTasksToTraversalOrder(List<SqlTask> sqlTasks, FeaturesConfig.TaskSpillingStrategy spillingStrategy) {
        switch (spillingStrategy) {
            case ORDER_BY_CREATE_TIME:
                sqlTasks.sort(ORDER_BY_CREATE_TIME);
                break;
            case ORDER_BY_REVOCABLE_BYTES: {
                // Compute each task's revocable reservation once, so task info is not rebuilt for every comparison.
                HashMap<TaskId, Long> taskRevocableReservations = new HashMap<>();
                for (SqlTask sqlTask : sqlTasks) {
                    taskRevocableReservations.put(sqlTask.getTaskId(), sqlTask.getTaskInfo().getStats().getRevocableMemoryReservationInBytes());
                }
                // Largest revocable reservation first.
                sqlTasks.sort(Comparator.comparingLong((SqlTask task) -> task == null ? 0L : taskRevocableReservations.getOrDefault(task.getTaskId(), 0L)).reversed());
                break;
            }
            case PER_TASK_MEMORY_THRESHOLD:
                throw new IllegalArgumentException("spilling strategy cannot be PER_TASK_MEMORY_THRESHOLD in MemoryRevokingScheduler");
            default:
                throw new UnsupportedOperationException("Unexpected spilling strategy in MemoryRevokingScheduler");
        }
    }

    /**
     * Returns a new mutable list of the RUNNING tasks whose query context uses this exact pool instance
     * (compared by identity).
     */
    private static List<SqlTask> findRunningTasksInMemoryPool(Collection<SqlTask> allCurrentTasks, MemoryPool memoryPool) {
        List<SqlTask> sqlTasks = new ArrayList<>();
        for (SqlTask task : allCurrentTasks) {
            if (task.getTaskState() == TaskState.RUNNING && task.getQueryContext().getMemoryPool() == memoryPool) {
                sqlTasks.add(task);
            }
        }
        return sqlTasks;
    }
}

