/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.netty;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.flink.runtime.clusterframework.types.ResourceID;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.NetworkClientHandler;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.TaskEventPublisher;
import org.apache.flink.runtime.io.network.netty.NettyMessage;
import org.apache.flink.runtime.io.network.netty.NettyPartitionRequestClient;
import org.apache.flink.runtime.io.network.netty.NettyProtocol;
import org.apache.flink.runtime.io.network.netty.NettyTestUtil;
import org.apache.flink.runtime.io.network.netty.PartitionRequestClientFactory;
import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException;
import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionProvider;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundHandlerAdapter;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelPromise;
import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil;
import org.apache.flink.testutils.TestingUtils;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.NotThrownAssert;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;

class ClientTransportErrorHandlingTest {
    private static final ConnectionID CONNECTION_ID = new ConnectionID(ResourceID.generate(), new InetSocketAddress("localhost", 0), 0);

    ClientTransportErrorHandlingTest() {
    }

    @Test
    void testExceptionOnWrite() throws Exception {
        NettyProtocol protocol = new NettyProtocol((ResultPartitionProvider)Mockito.mock(ResultPartitionProvider.class), (TaskEventPublisher)Mockito.mock(TaskEventDispatcher.class)){

            public ChannelHandler[] getServerChannelHandlers() {
                return new ChannelHandler[0];
            }
        };
        NettyTestUtil.NettyServerAndClient serverAndClient = NettyTestUtil.initServerAndClient(protocol);
        Channel ch = NettyTestUtil.connect(serverAndClient);
        NetworkClientHandler handler = this.getClientHandler(ch);
        ch.pipeline().addFirst(new ChannelHandler[]{new ChannelOutboundHandlerAdapter(){
            int writeNum = 0;

            public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
                if (this.writeNum >= 1) {
                    throw new RuntimeException("Expected test exception.");
                }
                ++this.writeNum;
                ctx.write(msg, promise);
            }
        }});
        NettyPartitionRequestClient requestClient = new NettyPartitionRequestClient(ch, handler, CONNECTION_ID, (PartitionRequestClientFactory)Mockito.mock(PartitionRequestClientFactory.class));
        RemoteInputChannel[] rich = new RemoteInputChannel[]{this.createRemoteInputChannel(), this.createRemoteInputChannel()};
        CountDownLatch sync = new CountDownLatch(1);
        ((RemoteInputChannel)Mockito.doAnswer(invocation -> {
            sync.countDown();
            return null;
        }).when((Object)rich[1])).onError((Throwable)ArgumentMatchers.isA(LocalTransportException.class));
        requestClient.requestSubpartition(new ResultPartitionID(), 0, rich[0], 0);
        requestClient.requestSubpartition(new ResultPartitionID(), 0, rich[1], 0);
        ((AbstractBooleanAssert)Assertions.assertThat((boolean)sync.await(TestingUtils.TESTING_DURATION.toMillis(), TimeUnit.MILLISECONDS)).withFailMessage("Timed out after waiting for " + TestingUtils.TESTING_DURATION.toMillis() + " ms to be notified about the channel error.", new Object[0])).isTrue();
        ((RemoteInputChannel)Mockito.verify((Object)rich[0], (VerificationMode)Mockito.times((int)0))).onError((Throwable)ArgumentMatchers.any(LocalTransportException.class));
        NettyTestUtil.shutdown(serverAndClient);
    }

    @Test
    void testWrappingOfRemoteErrorMessage() throws Exception {
        RemoteInputChannel[] rich;
        EmbeddedChannel ch = this.createEmbeddedChannel();
        NetworkClientHandler handler = this.getClientHandler((Channel)ch);
        for (RemoteInputChannel r : rich = new RemoteInputChannel[]{this.createRemoteInputChannel(), this.createRemoteInputChannel()}) {
            Mockito.when((Object)r.getInputChannelId()).thenReturn((Object)new InputChannelID());
            handler.addInputChannel(r);
        }
        ch.pipeline().fireChannelRead((Object)new NettyMessage.ErrorResponse((Throwable)new RuntimeException("Expected test exception"), rich[0].getInputChannelId()));
        ((NotThrownAssert)Assertions.assertThatNoException().describedAs("The exception reached the end of the pipeline and was not handled correctly by the last handler.", new Object[0])).isThrownBy(() -> ((EmbeddedChannel)ch).checkException());
        ((RemoteInputChannel)Mockito.verify((Object)rich[0], (VerificationMode)Mockito.times((int)1))).onError((Throwable)ArgumentMatchers.isA(RemoteTransportException.class));
        ((RemoteInputChannel)Mockito.verify((Object)rich[1], (VerificationMode)Mockito.never())).onError((Throwable)ArgumentMatchers.any(Throwable.class));
        ch.pipeline().fireChannelRead((Object)new NettyMessage.ErrorResponse((Throwable)new RuntimeException("Expected test exception")));
        ((NotThrownAssert)Assertions.assertThatNoException().describedAs("The exception reached the end of the pipeline and was not handled correctly by the last handler.", new Object[0])).isThrownBy(() -> ((EmbeddedChannel)ch).checkException());
        ((RemoteInputChannel)Mockito.verify((Object)rich[0], (VerificationMode)Mockito.times((int)2))).onError((Throwable)ArgumentMatchers.isA(RemoteTransportException.class));
        ((RemoteInputChannel)Mockito.verify((Object)rich[1], (VerificationMode)Mockito.times((int)1))).onError((Throwable)ArgumentMatchers.isA(RemoteTransportException.class));
    }

    @Test
    void testExceptionOnRemoteClose() throws Exception {
        NettyProtocol protocol = new NettyProtocol((ResultPartitionProvider)Mockito.mock(ResultPartitionProvider.class), (TaskEventPublisher)Mockito.mock(TaskEventDispatcher.class)){

            public ChannelHandler[] getServerChannelHandlers() {
                return new ChannelHandler[]{new ChannelInboundHandlerAdapter(){

                    public void channelRead(ChannelHandlerContext ctx, Object msg) {
                        ctx.channel().close();
                    }
                }};
            }
        };
        NettyTestUtil.NettyServerAndClient serverAndClient = NettyTestUtil.initServerAndClient(protocol);
        Channel ch = NettyTestUtil.connect(serverAndClient);
        NetworkClientHandler handler = this.getClientHandler(ch);
        RemoteInputChannel[] rich = new RemoteInputChannel[]{this.createRemoteInputChannel(), this.createRemoteInputChannel()};
        CountDownLatch sync = new CountDownLatch(rich.length);
        Answer countDownLatch = invocation -> {
            sync.countDown();
            return null;
        };
        for (RemoteInputChannel r : rich) {
            ((RemoteInputChannel)Mockito.doAnswer((Answer)countDownLatch).when((Object)r)).onError((Throwable)ArgumentMatchers.any(Throwable.class));
            handler.addInputChannel(r);
        }
        ch.writeAndFlush((Object)Unpooled.buffer().writerIndex(16));
        ((AbstractBooleanAssert)Assertions.assertThat((boolean)sync.await(TestingUtils.TESTING_DURATION.toMillis(), TimeUnit.MILLISECONDS)).withFailMessage("Timed out after waiting for " + TestingUtils.TESTING_DURATION.toMillis() + " ms to be notified about remote connection close.", new Object[0])).isTrue();
        for (RemoteInputChannel r : rich) {
            ((RemoteInputChannel)Mockito.verify((Object)r)).onError((Throwable)ArgumentMatchers.isA(RemoteTransportException.class));
        }
        NettyTestUtil.shutdown(serverAndClient);
    }

    @Test
    void testExceptionCaught() throws Exception {
        RemoteInputChannel[] rich;
        EmbeddedChannel ch = this.createEmbeddedChannel();
        NetworkClientHandler handler = this.getClientHandler((Channel)ch);
        for (RemoteInputChannel r : rich = new RemoteInputChannel[]{this.createRemoteInputChannel(), this.createRemoteInputChannel()}) {
            Mockito.when((Object)r.getInputChannelId()).thenReturn((Object)new InputChannelID());
            handler.addInputChannel(r);
        }
        ch.pipeline().fireExceptionCaught((Throwable)new Exception());
        ((NotThrownAssert)Assertions.assertThatNoException().describedAs("The exception reached the end of the pipeline and was not handled correctly by the last handler.", new Object[0])).isThrownBy(() -> ((EmbeddedChannel)ch).checkException());
        for (RemoteInputChannel r : rich) {
            ((RemoteInputChannel)Mockito.verify((Object)r)).onError((Throwable)ArgumentMatchers.isA(LocalTransportException.class));
        }
    }

    @Test
    void testConnectionResetByPeer() throws Throwable {
        EmbeddedChannel ch = this.createEmbeddedChannel();
        NetworkClientHandler handler = this.getClientHandler((Channel)ch);
        RemoteInputChannel rich = this.addInputChannel(handler);
        Throwable[] error = new Throwable[1];
        ((RemoteInputChannel)Mockito.doAnswer(invocation -> {
            Throwable cause = (Throwable)invocation.getArguments()[0];
            try {
                Assertions.assertThat((Throwable)cause).isInstanceOf(RemoteTransportException.class);
                Assertions.assertThat((Throwable)cause).hasMessageNotContaining("Connection reset by peer");
                Assertions.assertThat((Throwable)cause.getCause()).isInstanceOf(IOException.class);
                Assertions.assertThat((Throwable)cause.getCause()).hasMessage("Connection reset by peer");
            }
            catch (Throwable t) {
                error[0] = t;
            }
            return null;
        }).when((Object)rich)).onError((Throwable)ArgumentMatchers.any(Throwable.class));
        ch.pipeline().fireExceptionCaught((Throwable)new IOException("Connection reset by peer"));
        Assertions.assertThat((Throwable)error[0]).isNull();
    }

    @Test
    void testChannelClosedOnExceptionDuringErrorNotification() throws Exception {
        EmbeddedChannel ch = this.createEmbeddedChannel();
        NetworkClientHandler handler = this.getClientHandler((Channel)ch);
        RemoteInputChannel rich = this.addInputChannel(handler);
        ((RemoteInputChannel)Mockito.doThrow((Throwable[])new Throwable[]{new RuntimeException("Expected test exception")}).when((Object)rich)).onError((Throwable)ArgumentMatchers.any(Throwable.class));
        ch.pipeline().fireExceptionCaught((Throwable)new Exception());
        Assertions.assertThat((boolean)ch.isActive()).isFalse();
    }

    private EmbeddedChannel createEmbeddedChannel() {
        NettyProtocol protocol = new NettyProtocol((ResultPartitionProvider)Mockito.mock(ResultPartitionProvider.class), (TaskEventPublisher)Mockito.mock(TaskEventDispatcher.class));
        return new EmbeddedChannel(protocol.getClientChannelHandlers());
    }

    private RemoteInputChannel addInputChannel(NetworkClientHandler clientHandler) throws IOException {
        RemoteInputChannel rich = this.createRemoteInputChannel();
        clientHandler.addInputChannel(rich);
        return rich;
    }

    private NetworkClientHandler getClientHandler(Channel ch) {
        NetworkClientHandler networkClientHandler = (NetworkClientHandler)ch.pipeline().get(NetworkClientHandler.class);
        networkClientHandler.setConnectionId(CONNECTION_ID);
        return networkClientHandler;
    }

    private RemoteInputChannel createRemoteInputChannel() {
        return (RemoteInputChannel)Mockito.when((Object)((RemoteInputChannel)Mockito.mock(RemoteInputChannel.class)).getInputChannelId()).thenReturn((Object)new InputChannelID()).getMock();
    }
}

