/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexFinishedEvent;
import org.apache.flink.runtime.jobmaster.event.JobEvent;
import org.apache.flink.runtime.scheduler.adaptivebatch.DefaultAdaptiveExecutionHandler;
import org.apache.flink.runtime.scheduler.adaptivebatch.JobGraphUpdateListener;
import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamGraphContext;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.DynamicCodeLoadingException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for {@link DefaultAdaptiveExecutionHandler}.
 *
 * <p>Every test builds the same batch pipeline: Source -(forward)-> Map -(rescale)-> Sink. Chaining
 * is disabled and parallelism is 1. Each test drives the handler by feeding it
 * {@link ExecutionJobVertexFinishedEvent}s for vertices that have "finished".
 */
class DefaultAdaptiveExecutionHandlerTest {
    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    /**
     * Clears the static edge-id registry of {@link TestingStreamGraphOptimizerStrategy}. Without
     * this, ids registered by one test (or one run) would leak into every later test in the JVM.
     */
    @AfterEach
    void clearOptimizerStrategyState() {
        TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.clear();
    }

    /** The initial job graph must exist and consist of the source vertex only. */
    @Test
    void testGetJobGraph() throws DynamicCodeLoadingException {
        JobGraph jobGraph = createAdaptiveExecutionHandler().getJobGraph();

        Assertions.assertThat(jobGraph).isNotNull();
        Assertions.assertThat(jobGraph.getNumberOfVertices()).isOne();
        Assertions.assertThat(jobGraph.getVertices().iterator().next().getName())
                .contains("Source");
    }

    /**
     * Finishing a vertex should cause the handler to add the next downstream vertex. It should also
     * report, through the registered listener, how many operators are still pending.
     */
    @Test
    void testHandleJobEvent() throws DynamicCodeLoadingException {
        List<JobVertex> newAddedJobVertices = new ArrayList<>();
        AtomicInteger pendingOperators = new AtomicInteger();
        DefaultAdaptiveExecutionHandler handler =
                createAdaptiveExecutionHandler(
                        (newVertices, pendingOperatorsCount) -> {
                            newAddedJobVertices.addAll(newVertices);
                            pendingOperators.set(pendingOperatorsCount);
                        },
                        createStreamGraph());
        JobGraph jobGraph = handler.getJobGraph();
        JobVertex source = findJobVertex(jobGraph, "Source");

        // Source finished: the Map vertex is added and one operator (the Sink) is still pending.
        handler.handleJobEvent(
                new ExecutionJobVertexFinishedEvent(source.getID(), Collections.emptyMap()));
        Assertions.assertThat(newAddedJobVertices).hasSize(1);
        Assertions.assertThat(newAddedJobVertices.get(0).getName()).contains("Map");
        Assertions.assertThat(pendingOperators.get()).isOne();

        // Map finished: the Sink vertex is added and nothing is pending any more.
        handler.handleJobEvent(
                new ExecutionJobVertexFinishedEvent(
                        newAddedJobVertices.get(0).getID(), Collections.emptyMap()));
        Assertions.assertThat(newAddedJobVertices).hasSize(2);
        Assertions.assertThat(newAddedJobVertices.get(1).getName()).contains("Sink");
        Assertions.assertThat(pendingOperators.get()).isZero();
    }

    /**
     * Registers {@link TestingStreamGraphOptimizerStrategy} and asks it to turn both edges into
     * rebalance edges. It then checks the ship strategy of the job edges the handler generates.
     */
    @Test
    void testOptimizeStreamGraph() throws DynamicCodeLoadingException {
        StreamGraph streamGraph = createStreamGraph();
        StreamNode source = findStreamNode(streamGraph, "Source");
        StreamNode map = findStreamNode(streamGraph, "Map");
        StreamEdge sourceToMap = source.getOutEdges().get(0);
        StreamEdge mapToSink = map.getOutEdges().get(0);

        Assertions.assertThat(sourceToMap.getPartitioner()).isInstanceOf(ForwardPartitioner.class);
        Assertions.assertThat(mapToSink.getPartitioner()).isInstanceOf(RescalePartitioner.class);

        streamGraph
                .getJobConfiguration()
                .set(
                        StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY,
                        Collections.singletonList(
                                TestingStreamGraphOptimizerStrategy.class.getName()));
        TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add(sourceToMap.getEdgeId());
        TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add(mapToSink.getEdgeId());

        DefaultAdaptiveExecutionHandler handler =
                createAdaptiveExecutionHandler((newVertices, pendingOperatorsCount) -> {}, streamGraph);
        JobGraph jobGraph = handler.getJobGraph();

        JobVertex sourceVertex = findJobVertex(jobGraph, "Source");
        handler.handleJobEvent(
                new ExecutionJobVertexFinishedEvent(sourceVertex.getID(), Collections.emptyMap()));
        IntermediateDataSet sourceOutput = sourceVertex.getProducedDataSets().get(0);
        Assertions.assertThat(sourceOutput.getConsumers()).hasSize(1);
        // The source->map edge was also requested for rebalancing, but it is expected to stay
        // forward. NOTE(review): presumably the stream graph context declines to replace a forward
        // partitioner; confirm against StreamGraphContext#modifyStreamEdge.
        Assertions.assertThat(sourceOutput.getConsumers().get(0).getShipStrategyName())
                .isEqualToIgnoringCase("forward");

        // Look the Map vertex up by name rather than by the graph's vertex iteration order.
        JobVertex mapVertex = findJobVertex(jobGraph, "Map");
        handler.handleJobEvent(
                new ExecutionJobVertexFinishedEvent(mapVertex.getID(), Collections.emptyMap()));
        IntermediateDataSet mapOutput = mapVertex.getProducedDataSets().get(0);
        Assertions.assertThat(mapOutput.getConsumers()).hasSize(1);
        Assertions.assertThat(mapOutput.getConsumers().get(0).getShipStrategyName())
                .isEqualToIgnoringCase("rebalance");
    }

    /**
     * The source starts at its configured parallelism. A parallelism decided for the source and
     * notified to the handler is expected to become the initial parallelism of the downstream Map
     * vertex.
     */
    @Test
    void testGetInitialParallelismAndNotifyJobVertexParallelismDecided()
            throws DynamicCodeLoadingException {
        StreamGraph streamGraph = createStreamGraph();
        DefaultAdaptiveExecutionHandler handler =
                createAdaptiveExecutionHandler((newVertices, pendingOperatorsCount) -> {}, streamGraph);
        JobGraph jobGraph = handler.getJobGraph();
        JobVertex source = findJobVertex(jobGraph, "Source");

        Assertions.assertThat(handler.getInitialParallelism(source.getID()))
                .isEqualTo(source.getParallelism());

        // Random value in [1, 8].
        int parallelism = 1 + new Random().nextInt(8);
        handler.notifyJobVertexParallelismDecided(source.getID(), parallelism);
        handler.handleJobEvent(
                new ExecutionJobVertexFinishedEvent(source.getID(), Collections.emptyMap()));

        JobVertex map = findJobVertex(jobGraph, "Map");
        Assertions.assertThat(handler.getInitialParallelism(map.getID())).isEqualTo(parallelism);
    }

    /** Creates a handler for the default pipeline with a listener that ignores all updates. */
    private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler()
            throws DynamicCodeLoadingException {
        return createAdaptiveExecutionHandler(
                (newVertices, pendingOperatorsCount) -> {}, createStreamGraph());
    }

    /**
     * Builds the stream graph used by every test. It is a batch-mode pipeline with operator
     * chaining disabled and parallelism 1: Source -(forward)-> Map -(rescale)-> Sink.
     */
    private StreamGraph createStreamGraph() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.disableOperatorChaining();
        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
        env.fromSequence(0L, 1L)
                .name("Source")
                .forward()
                .map(i -> i)
                .name("Map")
                .rescale()
                .print()
                .name("Sink")
                .disableChaining();
        env.setParallelism(1);
        return env.getStreamGraph();
    }

    /**
     * Creates a handler for {@code streamGraph} and registers {@code listener} for job graph
     * updates. The handler uses the shared test executor.
     */
    private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler(
            JobGraphUpdateListener listener, StreamGraph streamGraph)
            throws DynamicCodeLoadingException {
        DefaultAdaptiveExecutionHandler handler =
                new DefaultAdaptiveExecutionHandler(
                        getClass().getClassLoader(), streamGraph, EXECUTOR_RESOURCE.getExecutor());
        handler.registerJobGraphUpdateListener(listener);
        return handler;
    }

    /**
     * Returns the first job vertex, in topological order, whose name contains {@code nameFragment}.
     *
     * @throws IllegalStateException if no such vertex exists (yet)
     */
    private static JobVertex findJobVertex(JobGraph jobGraph, String nameFragment) {
        return jobGraph.getVerticesSortedTopologicallyFromSources().stream()
                .filter(jobVertex -> jobVertex.getName().contains(nameFragment))
                .findFirst()
                .orElseThrow(
                        () ->
                                new IllegalStateException(
                                        "No job vertex whose name contains '"
                                                + nameFragment
                                                + "'"));
    }

    /**
     * Returns the first stream node whose operator name contains {@code nameFragment}.
     *
     * @throws IllegalStateException if no such node exists
     */
    private static StreamNode findStreamNode(StreamGraph streamGraph, String nameFragment) {
        return streamGraph.getStreamNodes().stream()
                .filter(node -> node.getOperatorName().contains(nameFragment))
                .findFirst()
                .orElseThrow(
                        () ->
                                new IllegalStateException(
                                        "No stream node whose operator name contains '"
                                                + nameFragment
                                                + "'"));
    }

    /**
     * Test optimization strategy. When operators finish, it requests that each of their out-edges
     * listed in {@link #convertToReBalanceEdgeIds} switch to a {@link RebalancePartitioner}.
     *
     * <p>The strategy is configured by class name, so the edge ids are passed in through a static
     * set. Tests that populate it rely on the enclosing test's {@code @AfterEach} hook to clear it.
     */
    public static final class TestingStreamGraphOptimizerStrategy
            implements StreamGraphOptimizationStrategy {
        private static final Set<String> convertToReBalanceEdgeIds = new HashSet<>();

        /**
         * Builds one rebalance update request for every listed out-edge of the finished operators.
         *
         * @return the result of {@link StreamGraphContext#modifyStreamEdge} for those requests
         */
        @Override
        public boolean onOperatorsFinished(
                OperatorsFinished operatorsFinished, StreamGraphContext context) {
            List<Integer> finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds();
            List<StreamEdgeUpdateRequestInfo> requestInfos = new ArrayList<>();
            for (Integer finishedStreamNodeId : finishedStreamNodeIds) {
                for (ImmutableStreamEdge outEdge :
                        context.getStreamGraph().getStreamNode(finishedStreamNodeId).getOutEdges()) {
                    if (!convertToReBalanceEdgeIds.contains(outEdge.getEdgeId())) {
                        continue;
                    }
                    StreamEdgeUpdateRequestInfo requestInfo =
                            new StreamEdgeUpdateRequestInfo(
                                    outEdge.getEdgeId(),
                                    outEdge.getSourceId(),
                                    outEdge.getTargetId());
                    requestInfo.withOutputPartitioner(new RebalancePartitioner<>());
                    requestInfos.add(requestInfo);
                }
            }
            return context.modifyStreamEdge(requestInfos);
        }
    }
}

