/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.JobStatus;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.core.failure.FailureEnricher;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.CheckpointsCleaner;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.execution.SuppressRestartsException;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IOMetrics;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.JobStatusListener;
import org.apache.flink.runtime.executiongraph.MarkPartitionFinishedStrategy;
import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy;
import org.apache.flink.runtime.executiongraph.failover.RestartBackoffTimeStrategy;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup;
import org.apache.flink.runtime.jobgraph.jsonplan.JsonPlanGenerator;
import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalResult;
import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology;
import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalVertex;
import org.apache.flink.runtime.metrics.groups.JobManagerJobMetricGroup;
import org.apache.flink.runtime.scheduler.DefaultExecutionDeployer;
import org.apache.flink.runtime.scheduler.DefaultScheduler;
import org.apache.flink.runtime.scheduler.ExecutionGraphFactory;
import org.apache.flink.runtime.scheduler.ExecutionOperations;
import org.apache.flink.runtime.scheduler.ExecutionSlotAllocatorFactory;
import org.apache.flink.runtime.scheduler.ExecutionVertexVersion;
import org.apache.flink.runtime.scheduler.ExecutionVertexVersioner;
import org.apache.flink.runtime.scheduler.VertexParallelismStore;
import org.apache.flink.runtime.scheduler.adaptivebatch.AllToAllBlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.PointwiseBlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismAndInputInfosDecider;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategyFactory;
import org.apache.flink.runtime.shuffle.ShuffleMaster;
import org.apache.flink.runtime.source.coordinator.SourceCoordinator;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.concurrent.FutureUtils;
import org.apache.flink.util.concurrent.ScheduledExecutor;
import org.slf4j.Logger;

public class AdaptiveBatchScheduler
extends DefaultScheduler {
    private final DefaultLogicalTopology logicalTopology;
    private final VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider;
    private final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId;
    private final Map<IntermediateDataSetID, BlockingResultInfo> blockingResultInfos;
    private final JobManagerOptions.HybridPartitionDataConsumeConstraint hybridPartitionDataConsumeConstraint;
    private final Map<JobVertexID, CompletableFuture<Integer>> sourceParallelismFuturesByJobVertexId;

    public AdaptiveBatchScheduler(Logger log, JobGraph jobGraph, Executor ioExecutor, Configuration jobMasterConfiguration, Consumer<ComponentMainThreadExecutor> startUpAction, ScheduledExecutor delayExecutor, ClassLoader userCodeLoader, CheckpointsCleaner checkpointsCleaner, CheckpointRecoveryFactory checkpointRecoveryFactory, JobManagerJobMetricGroup jobManagerJobMetricGroup, SchedulingStrategyFactory schedulingStrategyFactory, FailoverStrategy.Factory failoverStrategyFactory, RestartBackoffTimeStrategy restartBackoffTimeStrategy, ExecutionOperations executionOperations, ExecutionVertexVersioner executionVertexVersioner, ExecutionSlotAllocatorFactory executionSlotAllocatorFactory, long initializationTimestamp, ComponentMainThreadExecutor mainThreadExecutor, JobStatusListener jobStatusListener, Collection<FailureEnricher> failureEnrichers, ExecutionGraphFactory executionGraphFactory, ShuffleMaster<?> shuffleMaster, Time rpcTimeout, VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider, int defaultMaxParallelism, JobManagerOptions.HybridPartitionDataConsumeConstraint hybridPartitionDataConsumeConstraint, Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId) throws Exception {
        super(log, jobGraph, ioExecutor, jobMasterConfiguration, startUpAction, delayExecutor, userCodeLoader, checkpointsCleaner, checkpointRecoveryFactory, jobManagerJobMetricGroup, schedulingStrategyFactory, failoverStrategyFactory, restartBackoffTimeStrategy, executionOperations, executionVertexVersioner, executionSlotAllocatorFactory, initializationTimestamp, mainThreadExecutor, jobStatusListener, failureEnrichers, executionGraphFactory, shuffleMaster, rpcTimeout, AdaptiveBatchScheduler.computeVertexParallelismStoreForDynamicGraph(jobGraph.getVertices(), defaultMaxParallelism), new DefaultExecutionDeployer.Factory());
        this.logicalTopology = DefaultLogicalTopology.fromJobGraph(jobGraph);
        this.vertexParallelismAndInputInfosDecider = Preconditions.checkNotNull(vertexParallelismAndInputInfosDecider);
        this.forwardGroupsByJobVertexId = Preconditions.checkNotNull(forwardGroupsByJobVertexId);
        this.blockingResultInfos = new HashMap<IntermediateDataSetID, BlockingResultInfo>();
        this.hybridPartitionDataConsumeConstraint = hybridPartitionDataConsumeConstraint;
        this.sourceParallelismFuturesByJobVertexId = new HashMap<JobVertexID, CompletableFuture<Integer>>();
    }

    @Override
    protected void startSchedulingInternal() {
        this.tryComputeSourceParallelismThenRunAsync((value, throwable) -> {
            if (this.getExecutionGraph().getState() == JobStatus.CREATED) {
                this.initializeVerticesIfPossible();
                super.startSchedulingInternal();
            }
        });
    }

    @Override
    protected void onTaskFinished(Execution execution, IOMetrics ioMetrics) {
        Preconditions.checkNotNull(ioMetrics);
        this.updateResultPartitionBytesMetrics(ioMetrics.getResultPartitionBytes());
        ExecutionVertexVersion currentVersion = this.executionVertexVersioner.getExecutionVertexVersion(execution.getVertex().getID());
        this.tryComputeSourceParallelismThenRunAsync((value, throwable) -> {
            if (this.executionVertexVersioner.isModified(currentVersion)) {
                this.log.debug("Initialization of vertices will be skipped, because the execution vertex version has been modified.");
                return;
            }
            this.initializeVerticesIfPossible();
            super.onTaskFinished(execution, ioMetrics);
        });
    }

    private void updateResultPartitionBytesMetrics(Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes) {
        Preconditions.checkNotNull(resultPartitionBytes);
        resultPartitionBytes.forEach((partitionId, partitionBytes) -> {
            IntermediateResult result = this.getExecutionGraph().getAllIntermediateResults().get(partitionId.getIntermediateDataSetID());
            Preconditions.checkNotNull(result);
            this.blockingResultInfos.compute(result.getId(), (ignored, resultInfo) -> {
                if (resultInfo == null) {
                    resultInfo = AdaptiveBatchScheduler.createFromIntermediateResult(result);
                }
                resultInfo.recordPartitionInfo(partitionId.getPartitionNumber(), (ResultPartitionBytes)partitionBytes);
                return resultInfo;
            });
        });
    }

    @Override
    public void allocateSlotsAndDeploy(List<ExecutionVertexID> verticesToDeploy) {
        List<ExecutionVertex> executionVertices = verticesToDeploy.stream().map(this::getExecutionVertex).collect(Collectors.toList());
        this.enrichInputBytesForExecutionVertices(executionVertices);
        super.allocateSlotsAndDeploy(verticesToDeploy);
    }

    @Override
    protected void resetForNewExecution(ExecutionVertexID executionVertexId) {
        ExecutionVertex executionVertex = this.getExecutionVertex(executionVertexId);
        if (executionVertex.getExecutionState() == ExecutionState.FINISHED) {
            executionVertex.getProducedPartitions().values().forEach(partition -> this.blockingResultInfos.computeIfPresent(partition.getIntermediateResult().getId(), (ignored, resultInfo) -> {
                resultInfo.resetPartitionInfo(partition.getPartitionNumber());
                return resultInfo;
            }));
        }
        super.resetForNewExecution(executionVertexId);
    }

    @Override
    protected MarkPartitionFinishedStrategy getMarkPartitionFinishedStrategy() {
        return rp -> rp.isBlockingOrBlockingPersistentResultPartition() || this.hybridPartitionDataConsumeConstraint.isOnlyConsumeFinishedPartition();
    }

    private void tryComputeSourceParallelismThenRunAsync(BiConsumer<Void, Throwable> action) {
        FutureUtils.ConjunctFuture<Void> dynamicSourceParallelismFutures = FutureUtils.waitForAll(this.computeDynamicSourceParallelism());
        ((CompletableFuture)dynamicSourceParallelismFutures.whenCompleteAsync((BiConsumer)action, (Executor)this.getMainThreadExecutor())).exceptionally(throwable -> {
            this.log.error("An unexpected error occurred while scheduling.", throwable);
            this.handleGlobalFailure(new SuppressRestartsException((Throwable)throwable));
            return null;
        });
    }

    public List<CompletableFuture<Integer>> computeDynamicSourceParallelism() {
        ArrayList<CompletableFuture<Integer>> dynamicSourceParallelismFutures = new ArrayList<CompletableFuture<Integer>>();
        for (ExecutionJobVertex jobVertex : this.getExecutionGraph().getVerticesTopologically()) {
            List<SourceCoordinator<?, ?>> sourceCoordinators = jobVertex.getSourceCoordinators();
            if (sourceCoordinators.isEmpty() || jobVertex.isParallelismDecided()) continue;
            if (this.sourceParallelismFuturesByJobVertexId.containsKey(jobVertex.getJobVertexId())) {
                dynamicSourceParallelismFutures.add(this.sourceParallelismFuturesByJobVertexId.get(jobVertex.getJobVertexId()));
                continue;
            }
            Optional<List<BlockingResultInfo>> consumedResultsInfo = this.tryGetConsumedResultsInfo(jobVertex);
            if (!consumedResultsInfo.isPresent()) continue;
            List<CompletableFuture<Integer>> sourceParallelismFutures = sourceCoordinators.stream().map(sourceCoordinator -> sourceCoordinator.inferSourceParallelismAsync(this.vertexParallelismAndInputInfosDecider.computeSourceParallelismUpperBound(jobVertex.getJobVertexId(), jobVertex.getMaxParallelism()), this.vertexParallelismAndInputInfosDecider.getDataVolumePerTask())).collect(Collectors.toList());
            CompletableFuture<Integer> dynamicSourceParallelismFuture = AdaptiveBatchScheduler.mergeDynamicParallelismFutures(sourceParallelismFutures);
            this.sourceParallelismFuturesByJobVertexId.put(jobVertex.getJobVertexId(), dynamicSourceParallelismFuture);
            dynamicSourceParallelismFutures.add(dynamicSourceParallelismFuture);
        }
        return dynamicSourceParallelismFutures;
    }

    @VisibleForTesting
    static CompletableFuture<Integer> mergeDynamicParallelismFutures(List<CompletableFuture<Integer>> sourceParallelismFutures) {
        return sourceParallelismFutures.stream().reduce(CompletableFuture.completedFuture(-1), (a, b) -> a.thenCombine((CompletionStage)b, Math::max));
    }

    @VisibleForTesting
    public void initializeVerticesIfPossible() {
        ArrayList<ExecutionJobVertex> newlyInitializedJobVertices = new ArrayList<ExecutionJobVertex>();
        try {
            long createTimestamp = System.currentTimeMillis();
            for (ExecutionJobVertex jobVertex : this.getExecutionGraph().getVerticesTopologically()) {
                if (jobVertex.isInitialized()) continue;
                if (this.canInitialize(jobVertex)) {
                    this.getExecutionGraph().initializeJobVertex(jobVertex, createTimestamp);
                    newlyInitializedJobVertices.add(jobVertex);
                    continue;
                }
                Optional<List<BlockingResultInfo>> consumedResultsInfo = this.tryGetConsumedResultsInfo(jobVertex);
                if (!consumedResultsInfo.isPresent()) continue;
                ParallelismAndInputInfos parallelismAndInputInfos = this.tryDecideParallelismAndInputInfos(jobVertex, consumedResultsInfo.get());
                this.changeJobVertexParallelism(jobVertex, parallelismAndInputInfos.getParallelism());
                Preconditions.checkState(this.canInitialize(jobVertex));
                this.getExecutionGraph().initializeJobVertex(jobVertex, createTimestamp, parallelismAndInputInfos.getJobVertexInputInfos());
                newlyInitializedJobVertices.add(jobVertex);
            }
        }
        catch (JobException ex) {
            this.log.error("Unexpected error occurred when initializing ExecutionJobVertex", (Throwable)ex);
            this.handleGlobalFailure(new SuppressRestartsException(ex));
        }
        if (newlyInitializedJobVertices.size() > 0) {
            this.updateTopology(newlyInitializedJobVertices);
        }
    }

    private ParallelismAndInputInfos tryDecideParallelismAndInputInfos(ExecutionJobVertex jobVertex, List<BlockingResultInfo> inputs) {
        int vertexInitialParallelism = jobVertex.getParallelism();
        ForwardGroup forwardGroup = this.forwardGroupsByJobVertexId.get(jobVertex.getJobVertexId());
        if (!jobVertex.isParallelismDecided() && forwardGroup != null) {
            Preconditions.checkState(!forwardGroup.isParallelismDecided());
        }
        int vertexMinParallelism = -1;
        if (this.sourceParallelismFuturesByJobVertexId.containsKey(jobVertex.getJobVertexId())) {
            int dynamicSourceParallelism = this.getDynamicSourceParallelism(jobVertex);
            if (!inputs.isEmpty()) {
                vertexMinParallelism = dynamicSourceParallelism;
            } else {
                vertexInitialParallelism = dynamicSourceParallelism;
            }
        }
        ParallelismAndInputInfos parallelismAndInputInfos = this.vertexParallelismAndInputInfosDecider.decideParallelismAndInputInfosForVertex(jobVertex.getJobVertexId(), inputs, vertexInitialParallelism, vertexMinParallelism, jobVertex.getMaxParallelism());
        if (vertexInitialParallelism == -1) {
            this.log.info("Parallelism of JobVertex: {} ({}) is decided to be {}.", new Object[]{jobVertex.getName(), jobVertex.getJobVertexId(), parallelismAndInputInfos.getParallelism()});
        } else {
            Preconditions.checkState(parallelismAndInputInfos.getParallelism() == vertexInitialParallelism);
        }
        if (forwardGroup != null && !forwardGroup.isParallelismDecided()) {
            forwardGroup.setParallelism(parallelismAndInputInfos.getParallelism());
            for (JobVertexID jobVertexId : forwardGroup.getJobVertexIds()) {
                ExecutionJobVertex executionJobVertex = this.getExecutionJobVertex(jobVertexId);
                if (!executionJobVertex.isParallelismDecided()) {
                    this.log.info("Parallelism of JobVertex: {} ({}) is decided to be {} according to forward group's parallelism.", new Object[]{executionJobVertex.getName(), executionJobVertex.getJobVertexId(), parallelismAndInputInfos.getParallelism()});
                    this.changeJobVertexParallelism(executionJobVertex, parallelismAndInputInfos.getParallelism());
                    continue;
                }
                Preconditions.checkState(parallelismAndInputInfos.getParallelism() == executionJobVertex.getParallelism());
            }
        }
        return parallelismAndInputInfos;
    }

    private int getDynamicSourceParallelism(ExecutionJobVertex jobVertex) {
        CompletableFuture<Integer> dynamicSourceParallelismFuture = this.sourceParallelismFuturesByJobVertexId.get(jobVertex.getJobVertexId());
        int dynamicSourceParallelism = -1;
        if (dynamicSourceParallelismFuture != null) {
            int vertexMaxParallelism;
            dynamicSourceParallelism = dynamicSourceParallelismFuture.join();
            if (dynamicSourceParallelism > (vertexMaxParallelism = jobVertex.getMaxParallelism())) {
                this.log.info("The dynamic inferred source parallelism {} is larger than the maximum parallelism {}. Use {} as the upper bound parallelism of source job vertex {}.", new Object[]{dynamicSourceParallelism, vertexMaxParallelism, vertexMaxParallelism, jobVertex.getJobVertexId()});
                dynamicSourceParallelism = vertexMaxParallelism;
            } else if (dynamicSourceParallelism > 0) {
                this.log.info("Parallelism of JobVertex: {} ({}) is decided to be {} according to dynamic source parallelism inference.", new Object[]{jobVertex.getName(), jobVertex.getJobVertexId(), dynamicSourceParallelism});
            } else {
                dynamicSourceParallelism = -1;
            }
        }
        return dynamicSourceParallelism;
    }

    private void enrichInputBytesForExecutionVertices(List<ExecutionVertex> executionVertices) {
        for (ExecutionVertex ev : executionVertices) {
            List<IntermediateResult> intermediateResults = ev.getJobVertex().getInputs();
            boolean hasHybridEdge = intermediateResults.stream().anyMatch(ir -> ir.getResultType() == ResultPartitionType.HYBRID_FULL || ir.getResultType() == ResultPartitionType.HYBRID_SELECTIVE);
            if (intermediateResults.isEmpty() || hasHybridEdge) continue;
            long inputBytes = 0L;
            for (IntermediateResult intermediateResult : intermediateResults) {
                ExecutionVertexInputInfo inputInfo = ev.getExecutionVertexInputInfo(intermediateResult.getId());
                IndexRange partitionIndexRange = inputInfo.getPartitionIndexRange();
                IndexRange subpartitionIndexRange = inputInfo.getSubpartitionIndexRange();
                BlockingResultInfo blockingResultInfo = Preconditions.checkNotNull(this.getBlockingResultInfo(intermediateResult.getId()));
                inputBytes += blockingResultInfo.getNumBytesProduced(partitionIndexRange, subpartitionIndexRange);
            }
            ev.setInputBytes(inputBytes);
        }
    }

    private void changeJobVertexParallelism(ExecutionJobVertex jobVertex, int parallelism) {
        if (jobVertex.isParallelismDecided()) {
            return;
        }
        jobVertex.getJobVertex().setParallelism(parallelism);
        try {
            this.getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(this.getJobGraph()));
        }
        catch (Throwable t) {
            this.log.warn("Cannot create JSON plan for job", t);
            this.getExecutionGraph().setJsonPlan("{}");
        }
        jobVertex.setParallelism(parallelism);
    }

    private Optional<List<BlockingResultInfo>> tryGetConsumedResultsInfo(ExecutionJobVertex jobVertex) {
        ArrayList<BlockingResultInfo> consumableResultInfo = new ArrayList<BlockingResultInfo>();
        DefaultLogicalVertex logicalVertex = this.logicalTopology.getVertex(jobVertex.getJobVertexId());
        Iterable<DefaultLogicalResult> consumedResults = logicalVertex.getConsumedResults();
        for (DefaultLogicalResult consumedResult : consumedResults) {
            ExecutionJobVertex producerVertex = this.getExecutionJobVertex(consumedResult.getProducer().getId());
            if (producerVertex.isFinished()) {
                BlockingResultInfo resultInfo = Preconditions.checkNotNull(this.blockingResultInfos.get(consumedResult.getId()));
                consumableResultInfo.add(resultInfo);
                continue;
            }
            return Optional.empty();
        }
        return Optional.of(consumableResultInfo);
    }

    private boolean canInitialize(ExecutionJobVertex jobVertex) {
        if (jobVertex.isInitialized() || !jobVertex.isParallelismDecided()) {
            return false;
        }
        for (JobEdge inputEdge : jobVertex.getJobVertex().getInputs()) {
            ExecutionJobVertex producerVertex = this.getExecutionGraph().getJobVertex(inputEdge.getSource().getProducer().getID());
            Preconditions.checkNotNull(producerVertex);
            if (producerVertex.isInitialized()) continue;
            return false;
        }
        return true;
    }

    private void updateTopology(List<ExecutionJobVertex> newlyInitializedJobVertices) {
        for (ExecutionJobVertex vertex : newlyInitializedJobVertices) {
            this.initializeOperatorCoordinatorsFor(vertex);
        }
        this.getExecutionGraph().notifyNewlyInitializedJobVertices(newlyInitializedJobVertices);
    }

    private void initializeOperatorCoordinatorsFor(ExecutionJobVertex vertex) {
        this.operatorCoordinatorHandler.registerAndStartNewCoordinators(vertex.getOperatorCoordinators(), this.getMainThreadExecutor(), vertex.getParallelism());
    }

    @VisibleForTesting
    public static VertexParallelismStore computeVertexParallelismStoreForDynamicGraph(Iterable<JobVertex> vertices, int defaultMaxParallelism) {
        return AdaptiveBatchScheduler.computeVertexParallelismStore(vertices, v -> {
            if (v.getParallelism() > 0) {
                return AdaptiveBatchScheduler.getDefaultMaxParallelism(v);
            }
            return defaultMaxParallelism;
        }, Function.identity());
    }

    private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) {
        Preconditions.checkArgument(result != null);
        if (result.getConsumingDistributionPattern() == DistributionPattern.POINTWISE) {
            return new PointwiseBlockingResultInfo(result.getId(), result.getNumberOfAssignedPartitions(), result.getPartitions()[0].getNumberOfSubpartitions());
        }
        return new AllToAllBlockingResultInfo(result.getId(), result.getNumberOfAssignedPartitions(), result.getPartitions()[0].getNumberOfSubpartitions(), result.isBroadcast());
    }

    @VisibleForTesting
    BlockingResultInfo getBlockingResultInfo(IntermediateDataSetID resultId) {
        return this.blockingResultInfos.get(resultId);
    }
}

