/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.Counter;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.SimpleCounter;
import org.apache.flink.runtime.blocklist.BlockedNode;
import org.apache.flink.runtime.blocklist.BlocklistOperations;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.SpeculativeExecutionVertex;
import org.apache.flink.runtime.io.network.partition.PartitionException;
import org.apache.flink.runtime.scheduler.adaptivebatch.SpeculativeExecutionHandler;
import org.apache.flink.runtime.scheduler.slowtaskdetector.ExecutionTimeBasedSlowTaskDetector;
import org.apache.flink.runtime.scheduler.slowtaskdetector.SlowTaskDetector;
import org.apache.flink.runtime.scheduler.slowtaskdetector.SlowTaskDetectorListener;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.concurrent.FutureUtils;
import org.slf4j.Logger;

/**
 * Default {@link SpeculativeExecutionHandler}.
 *
 * <p>It registers itself as the {@link SlowTaskDetectorListener} of an {@link
 * ExecutionTimeBasedSlowTaskDetector}. When slow tasks are reported, it optionally adds the nodes
 * hosting the slow attempts to the blocklist. It then creates new speculative attempts for slow
 * vertices that support concurrent execution attempts, up to {@code maxConcurrentExecutions}
 * current executions per vertex, and hands them to the supplied allocate-and-deploy function. It
 * also maintains the {@code numSlowExecutionVertices} gauge and the {@code
 * numEffectiveSpeculativeExecutions} counter.
 */
public class DefaultSpeculativeExecutionHandler
        implements SpeculativeExecutionHandler, SlowTaskDetectorListener {

    /**
     * Maximum number of current executions a vertex may have. New speculative attempts only fill
     * the gap between this value and the vertex's current execution count.
     */
    private final int maxConcurrentExecutions;

    /** How long slow nodes stay blocked; a zero duration disables slow-node blocking. */
    private final Duration blockSlowNodeDuration;

    private final BlocklistOperations blocklistOperations;

    private final SlowTaskDetector slowTaskDetector;

    /** Number of vertices reported in the most recent slow-task notification. */
    private long numSlowExecutionVertices;

    /** Counts finished attempts that were not the original attempt of their vertex. */
    private final Counter numEffectiveSpeculativeExecutionsCounter;

    private final Function<ExecutionVertexID, ExecutionVertex> executionVertexRetriever;

    private final Supplier<Map<ExecutionAttemptID, Execution>> registerExecutionsSupplier;

    private final BiConsumer<List<Execution>, Collection<ExecutionVertexID>>
            allocateSlotsAndDeployFunction;

    private final Logger log;

    /**
     * Creates the handler.
     *
     * @param jobMasterConfiguration source of the speculative-execution options and of the slow
     *     task detector configuration
     * @param blocklistOperations used to block nodes that host slow attempts
     * @param executionVertexRetriever looks up an execution vertex by id; the returned vertex is
     *     cast to {@link SpeculativeExecutionVertex}
     * @param registerExecutionsSupplier supplies the registered executions keyed by attempt id
     * @param allocateSlotsAndDeployFunction receives newly created speculative attempts and the
     *     vertices they belong to
     * @param log logger used for informational messages
     * @throws IllegalArgumentException if the configured block duration is negative
     * @throws NullPointerException if any non-configuration argument is {@code null}
     */
    public DefaultSpeculativeExecutionHandler(
            Configuration jobMasterConfiguration,
            BlocklistOperations blocklistOperations,
            Function<ExecutionVertexID, ExecutionVertex> executionVertexRetriever,
            Supplier<Map<ExecutionAttemptID, Execution>> registerExecutionsSupplier,
            BiConsumer<List<Execution>, Collection<ExecutionVertexID>>
                    allocateSlotsAndDeployFunction,
            Logger log) {
        this.maxConcurrentExecutions =
                jobMasterConfiguration.get(BatchExecutionOptions.SPECULATIVE_MAX_CONCURRENT_EXECUTIONS);
        this.blockSlowNodeDuration =
                jobMasterConfiguration.get(BatchExecutionOptions.BLOCK_SLOW_NODE_DURATION);
        Preconditions.checkArgument(
                !this.blockSlowNodeDuration.isNegative(),
                "The blocking duration should not be negative.");
        this.blocklistOperations = Preconditions.checkNotNull(blocklistOperations);
        this.slowTaskDetector = new ExecutionTimeBasedSlowTaskDetector(jobMasterConfiguration);
        this.numEffectiveSpeculativeExecutionsCounter = new SimpleCounter();
        this.executionVertexRetriever = Preconditions.checkNotNull(executionVertexRetriever);
        this.registerExecutionsSupplier = Preconditions.checkNotNull(registerExecutionsSupplier);
        this.allocateSlotsAndDeployFunction =
                Preconditions.checkNotNull(allocateSlotsAndDeployFunction);
        this.log = Preconditions.checkNotNull(log);
    }

    /** Registers the metrics and starts the slow task detector with this handler as listener. */
    @Override
    public void init(
            ExecutionGraph executionGraph,
            ComponentMainThreadExecutor mainThreadExecutor,
            MetricGroup metricGroup) {
        metricGroup.gauge("numSlowExecutionVertices", () -> numSlowExecutionVertices);
        metricGroup.counter(
                "numEffectiveSpeculativeExecutions", numEffectiveSpeculativeExecutionsCounter);
        slowTaskDetector.start(executionGraph, this, mainThreadExecutor);
    }

    /** Stops the slow task detector. */
    @Override
    public void stopSlowTaskDetector() {
        slowTaskDetector.stop();
    }

    /**
     * Records a finished attempt and cancels the vertex's pending executions.
     *
     * <p>A finished attempt that is not the vertex's original attempt counts as an effective
     * speculative execution.
     */
    @Override
    public void notifyTaskFinished(
            Execution execution,
            Function<ExecutionVertexID, CompletableFuture<?>> cancelPendingExecutionsFunction) {
        if (!isOriginalAttempt(execution)) {
            numEffectiveSpeculativeExecutionsCounter.inc();
        }
        FutureUtils.assertNoException(
                cancelPendingExecutionsFunction.apply(execution.getVertex().getID()));
    }

    /** Returns whether {@code execution} is the original attempt of its vertex. */
    private boolean isOriginalAttempt(Execution execution) {
        return getExecutionVertex(execution.getVertex().getID())
                .isOriginalAttempt(execution.getAttemptNumber());
    }

    /** Archives the failed attempt on its speculative execution vertex. */
    @Override
    public void notifyTaskFailed(Execution execution) {
        final SpeculativeExecutionVertex executionVertex =
                getExecutionVertex(execution.getVertex().getID());
        executionVertex.archiveFailedExecution(execution.getAttemptId());
    }

    /**
     * Decides whether a failed attempt can be treated as a local attempt failure.
     *
     * @return {@code true} if the failure was handled locally through {@code
     *     handleLocalExecutionAttemptFailure}. Returns {@code false} if no current attempt of the
     *     vertex can still finish, or if {@code error} contains a {@link PartitionException}; the
     *     failure is then left to the caller.
     */
    @Override
    public boolean handleTaskFailure(
            Execution failedExecution,
            @Nullable Throwable error,
            BiConsumer<Execution, Throwable> handleLocalExecutionAttemptFailure) {
        final SpeculativeExecutionVertex executionVertex =
                getExecutionVertex(failedExecution.getVertex().getID());
        if (!isExecutionVertexPossibleToFinish(executionVertex)
                || ExceptionUtils.findThrowable(error, PartitionException.class).isPresent()) {
            return false;
        }
        handleLocalExecutionAttemptFailure.accept(failedExecution, error);
        return true;
    }

    /**
     * Returns whether any current execution of the vertex is in a state from which it may still
     * finish. Every current execution is checked to not be {@link ExecutionState#FINISHED}.
     */
    private static boolean isExecutionVertexPossibleToFinish(
            SpeculativeExecutionVertex executionVertex) {
        boolean anyExecutionPossibleToFinish = false;
        // No short-circuit: the FINISHED sanity check must run for every current execution.
        for (Execution execution : executionVertex.getCurrentExecutions()) {
            final ExecutionState state = execution.getState();
            Preconditions.checkState(
                    state != ExecutionState.FINISHED,
                    "Execution %s of a vertex handling an attempt failure is unexpectedly FINISHED.",
                    execution.getAttemptId());
            if (isPossibleToFinish(state)) {
                anyExecutionPossibleToFinish = true;
            }
        }
        return anyExecutionPossibleToFinish;
    }

    /** States from which an execution may still progress to completion. */
    private static boolean isPossibleToFinish(ExecutionState state) {
        switch (state) {
            case CREATED:
            case SCHEDULED:
            case DEPLOYING:
            case INITIALIZING:
            case RUNNING:
                return true;
            default:
                return false;
        }
    }

    /**
     * Reacts to a slow-task report: updates the slow-vertex gauge, blocks slow nodes if enabled,
     * and creates and deploys speculative attempts for eligible vertices.
     *
     * <p>A vertex is eligible if it supports concurrent execution attempts and has fewer than
     * {@code maxConcurrentExecutions} current executions. The allocate-and-deploy function is
     * always invoked, possibly with empty collections.
     */
    @Override
    public void notifySlowTasks(Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks) {
        final long currentTimestamp = System.currentTimeMillis();
        numSlowExecutionVertices = slowTasks.size();

        // Update the blocklist before any new speculative attempts are created and deployed.
        blockSlowNodes(slowTasks, currentTimestamp);

        final List<Execution> newSpeculativeExecutions = new ArrayList<>();
        final Set<ExecutionVertexID> verticesToDeploy = new HashSet<>();
        for (ExecutionVertexID executionVertexId : slowTasks.keySet()) {
            final SpeculativeExecutionVertex executionVertex =
                    getExecutionVertex(executionVertexId);
            if (!executionVertex.isSupportsConcurrentExecutionAttempts()) {
                continue;
            }

            final int currentConcurrentExecutions = executionVertex.getCurrentExecutions().size();
            final int newSpeculativeExecutionsToDeploy =
                    maxConcurrentExecutions - currentConcurrentExecutions;
            if (newSpeculativeExecutionsToDeploy <= 0) {
                continue;
            }

            log.info(
                    "{} ({}) is detected as a slow vertex, create and deploy {} new speculative executions for it.",
                    executionVertex.getTaskNameWithSubtaskIndex(),
                    executionVertex.getID(),
                    newSpeculativeExecutionsToDeploy);

            final List<Execution> attempts =
                    IntStream.range(0, newSpeculativeExecutionsToDeploy)
                            .mapToObj(
                                    i ->
                                            executionVertex.createNewSpeculativeExecution(
                                                    currentTimestamp))
                            .collect(Collectors.toList());

            setupSubtaskGatewayForAttempts(executionVertex, attempts);
            verticesToDeploy.add(executionVertexId);
            newSpeculativeExecutions.addAll(attempts);
        }

        allocateSlotsAndDeployFunction.accept(newSpeculativeExecutions, verticesToDeploy);
    }

    /**
     * Adds the nodes hosting the reported slow attempts to the blocklist until {@code
     * currentTimestamp + blockSlowNodeDuration}. Does nothing if the block duration is zero.
     */
    private void blockSlowNodes(
            Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks,
            long currentTimestamp) {
        if (blockSlowNodeDuration.isZero()) {
            return;
        }
        final long blockedEndTimestamp = currentTimestamp + blockSlowNodeDuration.toMillis();
        final Collection<BlockedNode> nodesToBlock =
                getSlowNodeIds(slowTasks).stream()
                        .map(
                                nodeId ->
                                        new BlockedNode(
                                                nodeId,
                                                "Node is detected to be slow.",
                                                blockedEndTimestamp))
                        .collect(Collectors.toList());
        blocklistOperations.addNewBlockedNodes(nodesToBlock);
    }

    /**
     * Resolves the distinct node ids of the task managers hosting the reported slow attempts.
     *
     * @throws NullPointerException if a reported attempt has no assigned slot
     */
    private Set<String> getSlowNodeIds(
            Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks) {
        // Fetch the registered executions once instead of once per slow attempt.
        final Map<ExecutionAttemptID, Execution> registeredExecutions =
                registerExecutionsSupplier.get();
        return slowTasks.values().stream()
                .flatMap(Collection::stream)
                .distinct()
                .map(registeredExecutions::get)
                .map(
                        execution -> {
                            Preconditions.checkNotNull(
                                    execution.getAssignedResource(),
                                    "The reported slow node have not been assigned a slot. This is unexpected and indicates that there is something wrong with the slow task detector.");
                            return execution.getAssignedResourceLocation();
                        })
                .map(TaskManagerLocation::getNodeId)
                .collect(Collectors.toSet());
    }

    /** Looks up the vertex and casts it to {@link SpeculativeExecutionVertex}. */
    private SpeculativeExecutionVertex getExecutionVertex(ExecutionVertexID executionVertexId) {
        return (SpeculativeExecutionVertex) executionVertexRetriever.apply(executionVertexId);
    }

    /**
     * Calls {@code setupSubtaskGatewayForAttempts} on every operator coordinator of the vertex's
     * job vertex, with the vertex's subtask index and the attempt numbers of {@code attempts}.
     */
    private void setupSubtaskGatewayForAttempts(
            SpeculativeExecutionVertex executionVertex, Collection<Execution> attempts) {
        final Set<Integer> attemptNumbers =
                attempts.stream().map(Execution::getAttemptNumber).collect(Collectors.toSet());
        final int subtaskIndex = executionVertex.getParallelSubtaskIndex();
        executionVertex
                .getJobVertex()
                .getOperatorCoordinators()
                .forEach(
                        operatorCoordinator ->
                                operatorCoordinator.setupSubtaskGatewayForAttempts(
                                        subtaskIndex, attemptNumbers));
    }

    /**
     * Undoes the effective-speculation count for a vertex about to be reset. If its current
     * attempt is FINISHED and was a speculative (non-original) attempt, the counter was
     * incremented for it in {@link #notifyTaskFinished}, so it is decremented here.
     */
    @Override
    public void resetForNewExecution(ExecutionVertexID executionVertexId) {
        final Execution execution =
                getExecutionVertex(executionVertexId).getCurrentExecutionAttempt();
        if (execution.getState() == ExecutionState.FINISHED && !isOriginalAttempt(execution)) {
            numEffectiveSpeculativeExecutionsCounter.dec();
        }
    }

    @VisibleForTesting
    long getNumSlowExecutionVertices() {
        return numSlowExecutionVertices;
    }

    @VisibleForTesting
    long getNumEffectiveSpeculativeExecutions() {
        return numEffectiveSpeculativeExecutionsCounter.getCount();
    }
}

