/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.execution.scheduler;

import com.facebook.presto.OutputBuffers;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.execution.LocationFactory;
import com.facebook.presto.execution.NodeTaskMap;
import com.facebook.presto.execution.QueryState;
import com.facebook.presto.execution.QueryStateMachine;
import com.facebook.presto.execution.RemoteTask;
import com.facebook.presto.execution.RemoteTaskFactory;
import com.facebook.presto.execution.SqlStageExecution;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.execution.StageInfo;
import com.facebook.presto.execution.StageState;
import com.facebook.presto.execution.scheduler.BroadcastOutputBufferManager;
import com.facebook.presto.execution.scheduler.ExecutionPolicy;
import com.facebook.presto.execution.scheduler.ExecutionSchedule;
import com.facebook.presto.execution.scheduler.FixedCountScheduler;
import com.facebook.presto.execution.scheduler.NodeScheduler;
import com.facebook.presto.execution.scheduler.NodeSelector;
import com.facebook.presto.execution.scheduler.OutputBufferManager;
import com.facebook.presto.execution.scheduler.PartitionedOutputBufferManager;
import com.facebook.presto.execution.scheduler.ScheduleResult;
import com.facebook.presto.execution.scheduler.SourcePartitionedScheduler;
import com.facebook.presto.execution.scheduler.SplitPlacementPolicy;
import com.facebook.presto.execution.scheduler.StageScheduler;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.split.SplitSource;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.StageExecutionPlan;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.concurrent.MoreFutures;
import io.airlift.concurrent.SetThreadName;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Schedules all stages of a single query.
 *
 * <p>The constructor builds one {@link SqlStageExecution} per node of the {@link StageExecutionPlan}
 * tree (pre-order, so the root stage is created first). It creates a {@link StageScheduler} for each
 * stage based on its fragment's distribution and wires parent/child stages together through a
 * {@link StageLinkage}. {@link #start()} submits the scheduling loop to the executor once.
 */
public class SqlQueryScheduler {
    private final QueryStateMachine queryStateMachine;
    private final ExecutionPolicy executionPolicy;
    private final Map<StageId, SqlStageExecution> stages;
    private final ExecutorService executor;
    private final StageId rootStageId;
    private final Map<StageId, StageScheduler> stageSchedulers;
    private final Map<StageId, StageLinkage> stageLinkages;
    private final AtomicBoolean started = new AtomicBoolean();

    /**
     * Creates the stage tree for {@code plan} and registers state listeners that drive the query state.
     *
     * @param queryStateMachine state machine of the query being scheduled
     * @param locationFactory used to create each stage's location
     * @param plan root of the stage execution plan tree
     * @param nodeScheduler source of node selectors for stage placement
     * @param remoteTaskFactory passed to every created stage
     * @param session query session (also supplies the hash partition count for FIXED stages)
     * @param splitBatchSize batch size for source-partitioned schedulers
     * @param executor executor for the stages and for the scheduling loop
     * @param rootOutputBuffers output buffers installed on the root stage
     * @param nodeTaskMap passed to every created stage
     * @param executionPolicy decides which stages to schedule in each round
     * @throws NullPointerException if {@code queryStateMachine}, {@code executor} or {@code executionPolicy} is null
     */
    public SqlQueryScheduler(QueryStateMachine queryStateMachine, LocationFactory locationFactory, StageExecutionPlan plan, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, Session session, int splitBatchSize, ExecutorService executor, OutputBuffers rootOutputBuffers, NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy) {
        this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
        this.executionPolicy = Objects.requireNonNull(executionPolicy, "executionPolicy is null");
        this.executor = Objects.requireNonNull(executor, "executor is null");

        ImmutableMap.Builder<StageId, StageScheduler> stageSchedulers = ImmutableMap.builder();
        ImmutableMap.Builder<StageId, StageLinkage> stageLinkages = ImmutableMap.builder();
        // createStages returns stages in pre-order, so the root stage is always the first element
        List<SqlStageExecution> allStages = createStages(Optional.empty(), new AtomicInteger(), locationFactory, plan, nodeScheduler, remoteTaskFactory, session, splitBatchSize, executor, nodeTaskMap, stageSchedulers, stageLinkages);

        SqlStageExecution rootStage = allStages.get(0);
        rootStage.setOutputBuffers(rootOutputBuffers);
        this.rootStageId = rootStage.getStageId();
        this.stages = allStages.stream().collect(ImmutableCollectors.toImmutableMap(SqlStageExecution::getStageId));
        this.stageSchedulers = stageSchedulers.build();
        this.stageLinkages = stageLinkages.build();

        // The root stage's terminal state decides the query's successful/canceled outcome.
        rootStage.addStateChangeListener(state -> {
            if (state == StageState.FINISHED) {
                queryStateMachine.transitionToFinished();
            }
            else if (state == StageState.CANCELED) {
                queryStateMachine.transitionToFailed(new PrestoException(StandardErrorCode.USER_CANCELED, "Query was canceled"));
            }
        });

        // Any stage can fail the query, and the first stage with tasks moves it from STARTING to RUNNING.
        for (SqlStageExecution stage : allStages) {
            stage.addStateChangeListener(state -> {
                if (queryStateMachine.isDone()) {
                    return;
                }
                if (state == StageState.FAILED) {
                    queryStateMachine.transitionToFailed(stage.getStageInfo().getFailureCause().toException());
                }
                else if (state == StageState.ABORTED) {
                    queryStateMachine.transitionToFailed(new PrestoException(StandardErrorCode.INTERNAL_ERROR, "Query stage was aborted"));
                }
                else if (queryStateMachine.getQueryState() == QueryState.STARTING && stage.hasTasks()) {
                    queryStateMachine.transitionToRunning();
                }
            });
        }
    }

    /**
     * Recursively creates the stage for {@code plan} and all of its sub-stages.
     *
     * <p>For each created stage, a scheduler is put into {@code stageSchedulers} and a linkage into
     * {@code stageLinkages}.
     *
     * @return the created stages in pre-order; element 0 is the stage for {@code plan}
     * @throws PrestoException if a SINGLE or FIXED stage finds no nodes to run on
     * @throws IllegalStateException if the fragment's distribution is not supported
     */
    private List<SqlStageExecution> createStages(Optional<SqlStageExecution> parent, AtomicInteger nextStageId, LocationFactory locationFactory, StageExecutionPlan plan, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, Session session, int splitBatchSize, ExecutorService executor, NodeTaskMap nodeTaskMap, ImmutableMap.Builder<StageId, StageScheduler> stageSchedulers, ImmutableMap.Builder<StageId, StageLinkage> stageLinkages) {
        ImmutableList.Builder<SqlStageExecution> stages = ImmutableList.builder();

        StageId stageId = new StageId(queryStateMachine.getQueryId(), String.valueOf(nextStageId.getAndIncrement()));
        SqlStageExecution stage = new SqlStageExecution(stageId, locationFactory.createStageLocation(stageId), plan.getFragment(), remoteTaskFactory, session, nodeTaskMap, executor);
        stages.add(stage);

        // Fixed-count stages know their partition count up front; it is propagated to the sub-stages.
        OptionalInt partitionCount = OptionalInt.empty();
        PlanFragment.PlanDistribution distribution = plan.getFragment().getDistribution();
        if (distribution == PlanFragment.PlanDistribution.SINGLE) {
            NodeSelector nodeSelector = nodeScheduler.createNodeSelector(null);
            List<Node> nodes = nodeSelector.selectRandomNodes(1);
            checkNodesAvailable(nodes);
            stageSchedulers.put(stageId, new FixedCountScheduler(stage, ImmutableMap.of(0, nodes.get(0))));
            partitionCount = OptionalInt.of(1);
        }
        else if (distribution == PlanFragment.PlanDistribution.FIXED) {
            NodeSelector nodeSelector = nodeScheduler.createNodeSelector(null);
            List<Node> nodes = nodeSelector.selectRandomNodes(SystemSessionProperties.getHashPartitionCount(session));
            checkNodesAvailable(nodes);
            // Partitions are numbered 0..n-1 in the order the selector returned the nodes.
            ImmutableMap.Builder<Integer, Node> partitionToNode = ImmutableMap.builder();
            int partition = 0;
            for (Node node : nodes) {
                partitionToNode.put(partition, node);
                partition++;
            }
            stageSchedulers.put(stageId, new FixedCountScheduler(stage, partitionToNode.build()));
            partitionCount = OptionalInt.of(partition);
        }
        else if (distribution == PlanFragment.PlanDistribution.COORDINATOR_ONLY) {
            NodeSelector nodeSelector = nodeScheduler.createNodeSelector(null);
            stageSchedulers.put(stageId, new FixedCountScheduler(stage, ImmutableMap.of(0, nodeSelector.selectCurrentNode())));
            partitionCount = OptionalInt.of(1);
        }
        else if (distribution == PlanFragment.PlanDistribution.SOURCE) {
            SplitSource splitSource = plan.getDataSource().get();
            NodeSelector nodeSelector = nodeScheduler.createNodeSelector(splitSource.getDataSourceName());
            SplitPlacementPolicy placementPolicy = new SplitPlacementPolicy(nodeSelector, stage::getAllTasks);
            stageSchedulers.put(stageId, new SourcePartitionedScheduler(stage, splitSource, placementPolicy, splitBatchSize));
        }
        else {
            throw new IllegalStateException("Unsupported partitioning: " + distribution);
        }

        ImmutableSet.Builder<SqlStageExecution> childStages = ImmutableSet.builder();
        for (StageExecutionPlan subStagePlan : plan.getSubStages()) {
            List<SqlStageExecution> subTree = createStages(Optional.of(stage), nextStageId, locationFactory, subStagePlan.withPartitionCount(partitionCount), nodeScheduler, remoteTaskFactory, session, splitBatchSize, executor, nodeTaskMap, stageSchedulers, stageLinkages);
            stages.addAll(subTree);
            // the first element of each sub-tree is the direct child stage
            childStages.add(subTree.get(0));
        }
        stageLinkages.put(stageId, new StageLinkage(plan.getFragment().getId(), parent, childStages.build()));
        return stages.build();
    }

    /**
     * Fails fast with a descriptive error when node selection returned nothing. Without this check,
     * SINGLE would throw an opaque IndexOutOfBoundsException and FIXED would build zero partitions.
     */
    private static void checkNodesAvailable(List<Node> nodes) {
        if (nodes.isEmpty()) {
            throw new PrestoException(StandardErrorCode.INTERNAL_ERROR, "No worker nodes available");
        }
    }

    /**
     * Returns the stage info tree rooted at the root stage, built from a snapshot of every stage's info.
     */
    public StageInfo getStageInfo() {
        Map<StageId, StageInfo> stageInfos = stages.values().stream()
                .map(SqlStageExecution::getStageInfo)
                .collect(ImmutableCollectors.toImmutableMap(StageInfo::getStageId));
        return buildStageInfo(rootStageId, stageInfos);
    }

    /**
     * Recursively attaches the child stage infos (taken from {@code stageInfos}) to the info of {@code stageId}.
     *
     * @throws IllegalArgumentException if {@code stageInfos} has no entry for {@code stageId}
     */
    private StageInfo buildStageInfo(StageId stageId, Map<StageId, StageInfo> stageInfos) {
        StageInfo parent = stageInfos.get(stageId);
        // report the missing id; the original formatted 'parent', which is always null here
        Preconditions.checkArgument(parent != null, "No stageInfo for %s", stageId);
        List<StageInfo> childStages = stageLinkages.get(stageId).getChildStageIds().stream()
                .map(childStageId -> buildStageInfo(childStageId, stageInfos))
                .collect(ImmutableCollectors.toImmutableList());
        if (childStages.isEmpty()) {
            return parent;
        }
        return new StageInfo(parent.getStageId(), parent.getState(), parent.getSelf(), parent.getPlan(), parent.getTypes(), parent.getStageStats(), parent.getTasks(), childStages, parent.getFailureCause());
    }

    /**
     * Returns the sum of {@link SqlStageExecution#getMemoryReservation()} over all stages.
     */
    public long getTotalMemoryReservation() {
        return stages.values().stream().mapToLong(SqlStageExecution::getMemoryReservation).sum();
    }

    /**
     * Submits the scheduling loop to the executor. Only the first call has an effect.
     */
    public void start() {
        if (started.compareAndSet(false, true)) {
            executor.submit(this::schedule);
        }
    }

    /**
     * Scheduling loop. Repeatedly asks the execution schedule for stages to schedule, runs their
     * schedulers, and propagates results through the stage linkages until the schedule is finished.
     * Any failure transitions the query to failed and is rethrown. The stage schedulers are always
     * closed on exit.
     */
    private void schedule() {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            Set<StageId> completedStages = new HashSet<>();
            ExecutionSchedule executionSchedule = executionPolicy.createExecutionSchedule(stages.values());
            while (!executionSchedule.isFinished()) {
                List<CompletableFuture<?>> blockedStages = new ArrayList<>();
                for (SqlStageExecution stage : executionSchedule.getStagesToSchedule()) {
                    stage.beginScheduling();

                    ScheduleResult result = stageSchedulers.get(stage.getStageId()).schedule();

                    if (result.isFinished()) {
                        stage.schedulingComplete();
                    }
                    else if (!result.getBlocked().isDone()) {
                        blockedStages.add(result.getBlocked());
                    }
                    stageLinkages.get(stage.getStageId()).processScheduleResults(stage.getState(), result.getNewTasks());
                }

                // Update every linkage once after its stage becomes done, including stages that were
                // not scheduled this round, so asynchronous state changes are propagated.
                for (SqlStageExecution stage : stages.values()) {
                    if (!completedStages.contains(stage.getStageId()) && stage.getState().isDone()) {
                        stageLinkages.get(stage.getStageId()).processScheduleResults(stage.getState(), ImmutableSet.of());
                        completedStages.add(stage.getStageId());
                    }
                }

                // Wait (at most 100 ms) for any blocked stage to unblock, then drop the wait futures.
                if (!blockedStages.isEmpty()) {
                    MoreFutures.tryGetFutureValue(MoreFutures.firstCompletedFuture(blockedStages), 100, TimeUnit.MILLISECONDS);
                    for (CompletableFuture<?> blockedStage : blockedStages) {
                        blockedStage.cancel(true);
                    }
                }
            }

            for (SqlStageExecution stage : stages.values()) {
                StageState state = stage.getState();
                if (state != StageState.SCHEDULED && state != StageState.RUNNING && !state.isDone()) {
                    throw new PrestoException(StandardErrorCode.INTERNAL_ERROR, String.format("Scheduling is complete, but stage %s is in state %s", stage.getStageId(), state));
                }
            }
        }
        catch (Throwable t) {
            queryStateMachine.transitionToFailed(t);
            throw Throwables.propagate(t);
        }
        finally {
            closeStageSchedulers();
        }
    }

    /**
     * Closes every stage scheduler. Each close failure fails the query and is collected as a
     * suppressed exception. If any close failed, the collecting exception is thrown.
     */
    private void closeStageSchedulers() {
        RuntimeException closeError = new RuntimeException();
        for (StageScheduler scheduler : stageSchedulers.values()) {
            try {
                scheduler.close();
            }
            catch (Throwable t) {
                queryStateMachine.transitionToFailed(t);
                closeError.addSuppressed(t);
            }
        }
        if (closeError.getSuppressed().length > 0) {
            throw closeError;
        }
    }

    /**
     * Cancels the given stage.
     *
     * @throws NullPointerException if the stage does not belong to this query
     */
    public void cancelStage(StageId stageId) {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            SqlStageExecution stage = Objects.requireNonNull(stages.get(stageId), () -> String.format("Stage %s does not exist", stageId));
            stage.cancel();
        }
    }

    /**
     * Aborts every stage of the query.
     */
    public void abort() {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            stages.values().forEach(SqlStageExecution::abort);
        }
    }

    /**
     * Connects one stage to its neighbours. New tasks of the stage are registered as exchange
     * locations on the parent and as output buffers on each child stage. Once the stage stops
     * adding tasks, the parent and children are told no more will come.
     */
    private static class StageLinkage {
        private final PlanFragmentId currentStageFragmentId;
        private final Optional<SqlStageExecution> parent;
        private final Set<OutputBufferManager> childOutputBufferManagers;
        private final Set<StageId> childStageIds;

        public StageLinkage(PlanFragmentId fragmentId, Optional<SqlStageExecution> parent, Set<SqlStageExecution> children) {
            this.currentStageFragmentId = fragmentId;
            this.parent = parent;
            this.childOutputBufferManagers = children.stream()
                    .map(StageLinkage::createOutputBufferManager)
                    .collect(ImmutableCollectors.toImmutableSet());
            this.childStageIds = children.stream()
                    .map(SqlStageExecution::getStageId)
                    .collect(ImmutableCollectors.toImmutableSet());
        }

        /**
         * Selects the output buffer manager for a child stage. Partitioned is used when the child's
         * fragment has a partition function; broadcast is used otherwise.
         */
        private static OutputBufferManager createOutputBufferManager(SqlStageExecution childStage) {
            if (childStage.getFragment().getPartitionFunction().isPresent()) {
                return new PartitionedOutputBufferManager(childStage::setOutputBuffers);
            }
            return new BroadcastOutputBufferManager(childStage::setOutputBuffers);
        }

        public Set<StageId> getChildStageIds() {
            return childStageIds;
        }

        /**
         * Propagates newly created tasks to the parent and child stages. Once {@code newState}
         * indicates that the stage will add no more tasks, signals completion to them as well.
         */
        public void processScheduleResults(StageState newState, Set<RemoteTask> newTasks) {
            for (RemoteTask remoteTask : newTasks) {
                parent.ifPresent(parentStage -> parentStage.addExchangeLocation(new SqlStageExecution.ExchangeLocation(currentStageFragmentId, remoteTask.getTaskInfo().getSelf())));
                childOutputBufferManagers.forEach(child -> child.addOutputBuffer(remoteTask.getTaskId(), remoteTask.getPartition()));
            }

            switch (newState) {
                case PLANNED:
                case SCHEDULING:
                    // tasks may still be added
                    break;
                case SCHEDULING_SPLITS:
                case SCHEDULED:
                case RUNNING:
                case FINISHED:
                case CANCELED:
                    // no more tasks will be added for this stage
                    parent.ifPresent(parentStage -> parentStage.noMoreExchangeLocationsFor(currentStageFragmentId));
                    childOutputBufferManagers.forEach(OutputBufferManager::noMoreOutputBuffers);
                    break;
                case ABORTED:
                case FAILED:
                    // NOTE(review): completion is intentionally NOT signalled for failed/aborted stages,
                    // presumably so the parent cannot finish normally on partial input; the query is
                    // failed by the stage state listeners instead — confirm before changing.
                    break;
                default:
                    break;
            }
        }
    }
}

