/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.integration.ip.tcp.connection;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;
import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.integration.ip.tcp.connection.TcpNioConnection;
import org.springframework.messaging.MessagingException;
import org.springframework.util.Assert;

/**
 * A {@link TcpNioConnection} that layers TLS over the NIO socket channel using an
 * {@link SSLEngine}.
 * <p>
 * Inbound network data handed to {@link #sendToPipe(ByteBuffer)} is unwrapped and any
 * resulting plain text is forwarded to the superclass pipe. Outbound plain text is wrapped
 * by {@link SSLChannelOutputStream}. When a writer is performing the handshake and needs
 * data from the peer, it blocks on a {@link Semaphore}. That semaphore is released as
 * inbound handshake processing progresses, when a fatal handshake error is seen, or when
 * the connection is closed.
 */
public class TcpNioSSLConnection
extends TcpNioConnection {

    /** Default time, in seconds, that a writer waits for handshake data from the peer. */
    private static final int DEFAULT_HANDSHAKE_TIMEOUT = 30;

    private final SSLEngine sslEngine;

    /** Destination buffer for unwrapped (plain text) inbound data; allocated by {@link #init()}. */
    private ByteBuffer decoded;

    /** Destination buffer for wrapped (encrypted) outbound data; allocated by {@link #init()}. */
    private ByteBuffer encoded;

    /** Released by the inbound side (or {@link #close()}) to wake a writer waiting for handshake data. */
    private final Semaphore semaphore = new Semaphore(0);

    /** Guards lazy creation of {@link #sslChannelOutputStream}. */
    private final Lock monitorLock = new ReentrantLock();

    /** Handshake wait timeout in seconds (see {@code waitForHandshakeData}). */
    private int handshakeTimeout = DEFAULT_HANDSHAKE_TIMEOUT;

    /** Set by {@link #decode(ByteBuffer)} when the current network buffer cannot be processed further. */
    private boolean needMoreNetworkData;

    /** A handshake failure seen on the inbound side; rethrown to a waiting writer. */
    private @Nullable SSLHandshakeException sslFatal;

    private volatile @Nullable SSLChannelOutputStream sslChannelOutputStream;

    /** True while a writer is inside {@link SSLChannelOutputStream#doWrite(ByteBuffer)}. */
    private volatile boolean writerActive;

    /**
     * Create a TLS connection over the given channel.
     * @param socketChannel the underlying socket channel.
     * @param server true for the server side of the connection; the engine is put in client
     * mode otherwise (see {@link #init()}).
     * @param lookupHost passed to the superclass.
     * @param applicationEventPublisher the event publisher, may be null.
     * @param connectionFactoryName the owning factory name, may be null.
     * @param sslEngine the engine used to wrap/unwrap traffic.
     */
    public TcpNioSSLConnection(SocketChannel socketChannel, boolean server, boolean lookupHost,
            @Nullable ApplicationEventPublisher applicationEventPublisher, @Nullable String connectionFactoryName,
            SSLEngine sslEngine) {

        super(socketChannel, server, lookupHost, applicationEventPublisher, connectionFactoryName);
        this.sslEngine = sslEngine;
    }

    /**
     * Set how long a writer waits for handshake data from the peer before failing.
     * @param handshakeTimeout the timeout in seconds (default 30).
     */
    public void setHandshakeTimeout(int handshakeTimeout) {
        this.handshakeTimeout = handshakeTimeout;
    }

    @Override
    public SSLSession getSslSession() {
        return this.sslEngine.getSession();
    }

    /**
     * Process as much of the supplied network data as possible. Any decoded plain text is
     * forwarded to the superclass pipe.
     * <p>
     * On return, the buffer is compacted if the last engine result was a buffer underflow
     * (a partial TLS record remains). Otherwise it is cleared.
     * @param networkBuffer the encrypted bytes from the socket (assumed positioned for reading).
     * @throws SSLHandshakeException if the handshake fails. Any writer waiting for handshake
     * data is woken so it can rethrow the same exception.
     * @throws IOException on other I/O or TLS errors.
     */
    @Override
    protected void sendToPipe(ByteBuffer networkBuffer) throws IOException {
        Assert.notNull(networkBuffer, "networkBuffer cannot be null");
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("sendToPipe " + this.sslEngine.getHandshakeStatus()
                    + ", remaining: " + networkBuffer.remaining());
        }
        SSLEngineResult result = null;
        while (!this.needMoreNetworkData) {
            try {
                result = decode(networkBuffer);
            }
            catch (SSLHandshakeException e) {
                // Publish the failure before waking a writer blocked in waitForHandshakeData().
                this.sslFatal = e;
                this.semaphore.release();
                throw e;
            }
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("result " + resultToString(result) + ", remaining: " + networkBuffer.remaining());
            }
        }
        this.needMoreNetworkData = false;
        if (result != null && SSLEngineResult.Status.BUFFER_UNDERFLOW == result.getStatus()) {
            // Keep the partial record for the next read.
            networkBuffer.compact();
        }
        else {
            networkBuffer.clear();
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("sendToPipe.x " + resultToString(result) + ", remaining: " + networkBuffer.remaining());
        }
    }

    /**
     * Perform one engine step appropriate to the current handshake status. Update
     * {@link #needMoreNetworkData} to tell the loop in {@link #sendToPipe(ByteBuffer)}
     * whether to stop.
     * @param networkBuffer the inbound network data.
     * @return the engine result, or a synthetic OK result if no engine operation was made.
     * @throws IOException on I/O or TLS errors.
     */
    private SSLEngineResult decode(ByteBuffer networkBuffer) throws IOException {
        SSLEngineResult.HandshakeStatus handshakeStatus = this.sslEngine.getHandshakeStatus();
        // Placeholder result; returned unchanged when no wrap/unwrap is performed below.
        SSLEngineResult result = new SSLEngineResult(SSLEngineResult.Status.OK, handshakeStatus, 0, 0);
        switch (handshakeStatus) {
            case NEED_TASK:
                runTasks();
                break;
            case NEED_WRAP:
                result = needWrap(networkBuffer, result);
                break;
            default:
                result = checkBytesProduced(networkBuffer);
                break;
        }
        switch (result.getHandshakeStatus()) {
            case FINISHED:
                resumeWriterIfNeeded();
                // intentional fall-through: after waking any writer, decide whether more data is needed
            case NOT_HANDSHAKING:
            case NEED_UNWRAP:
                this.needMoreNetworkData = result.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW
                        || networkBuffer.remaining() == 0;
                break;
            default:
                // Other statuses leave needMoreNetworkData false, so sendToPipe() calls decode() again.
                break;
        }
        return result;
    }

    /**
     * Unwrap inbound data into {@link #decoded} and forward any produced plain text to the
     * superclass pipe.
     * @param networkBuffer the inbound network data.
     * @return the unwrap result.
     * @throws IOException on I/O or TLS errors.
     */
    private SSLEngineResult checkBytesProduced(ByteBuffer networkBuffer) throws IOException {
        this.decoded.clear();
        SSLEngineResult result = this.sslEngine.unwrap(networkBuffer, this.decoded);
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("After unwrap: " + resultToString(result));
        }
        if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
            // Grow to the session's application buffer size so a subsequent unwrap can succeed.
            this.decoded = allocateEncryptionBuffer(this.sslEngine.getSession().getApplicationBufferSize());
        }
        if (result.bytesProduced() > 0) {
            this.decoded.flip();
            super.sendToPipe(this.decoded);
        }
        return result;
    }

    /**
     * Handle a NEED_WRAP status on the inbound side.
     * <p>
     * If a writer is active, it is woken and {@code result} is returned unchanged; the writer
     * then produces the handshake data. Otherwise the engine is wrapped here and any output
     * is written directly.
     * @param networkBuffer the source buffer passed to {@link SSLEngine#wrap}.
     * @param result the placeholder result to return when no wrap is performed.
     * @return the wrap result, or {@code result} if a writer was woken.
     * @throws IOException on I/O or TLS errors.
     */
    private SSLEngineResult needWrap(ByteBuffer networkBuffer, SSLEngineResult result) throws IOException {
        if (resumeWriterIfNeeded()) {
            return result;
        }
        this.encoded.clear();
        SSLEngineResult engineResult = this.sslEngine.wrap(networkBuffer, this.encoded);
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("After wrap: " + resultToString(engineResult));
        }
        if (engineResult.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
            this.encoded = allocateEncryptionBuffer(this.sslEngine.getSession().getPacketBufferSize());
        }
        else {
            this.encoded.flip();
            getSSLChannelOutputStream().writeEncoded(this.encoded);
        }
        return engineResult;
    }

    /**
     * Release one semaphore permit if a writer is currently active.
     * @return true if a writer was active (and a permit was released).
     */
    private boolean resumeWriterIfNeeded() {
        if (!this.writerActive) {
            return false;
        }
        if (this.logger.isTraceEnabled()) {
            this.logger.trace("Waking sender, permits: " + this.semaphore.availablePermits());
        }
        this.semaphore.release();
        return true;
    }

    /** Run all of the engine's delegated tasks on the calling thread. */
    private void runTasks() {
        Runnable task;
        while ((task = this.sslEngine.getDelegatedTask()) != null) {
            task.run();
        }
    }

    /**
     * Run delegated tasks if {@code result} reports NEED_TASK.
     * @param result the last engine result.
     * @return the engine's handshake status after any tasks have run.
     */
    private SSLEngineResult.HandshakeStatus runTasksIfNeeded(SSLEngineResult result) {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Running tasks if needed " + resultToString(result));
        }
        if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
            runTasks();
        }
        SSLEngineResult.HandshakeStatus handshakeStatus = this.sslEngine.getHandshakeStatus();
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("New handshake status " + handshakeStatus);
        }
        return handshakeStatus;
    }

    /**
     * Allocate the initial 2048-byte encode/decode buffers and configure the engine's
     * client/server mode. The buffers are null until this is called.
     */
    public void init() {
        this.decoded = allocateEncryptionBuffer(2048);
        this.encoded = allocateEncryptionBuffer(2048);
        initializeEngine();
    }

    /**
     * Allocate a direct or heap buffer, according to {@code isUsingDirectBuffers()}.
     * @param size the buffer capacity.
     * @return the new buffer.
     */
    private ByteBuffer allocateEncryptionBuffer(int size) {
        return isUsingDirectBuffers() ? ByteBuffer.allocateDirect(size) : ByteBuffer.allocate(size);
    }

    /** Put the engine in client mode unless this is the server side of the connection. */
    private void initializeEngine() {
        this.sslEngine.setUseClientMode(!isServer());
    }

    /**
     * Return the TLS-wrapping output stream. It is created lazily, under
     * {@link #monitorLock}, around the superclass stream.
     */
    @Override
    protected TcpNioConnection.ChannelOutputStream getChannelOutputStream() {
        this.monitorLock.lock();
        try {
            SSLChannelOutputStream stream = this.sslChannelOutputStream;
            if (stream == null) {
                stream = new SSLChannelOutputStream(super.getChannelOutputStream());
                this.sslChannelOutputStream = stream;
            }
            return stream;
        }
        finally {
            this.monitorLock.unlock();
        }
    }

    /**
     * Return the TLS output stream. The already-created instance is used without locking
     * when available.
     * @return the SSL channel output stream.
     */
    protected SSLChannelOutputStream getSSLChannelOutputStream() {
        SSLChannelOutputStream stream = this.sslChannelOutputStream;
        if (stream != null) {
            return stream;
        }
        return (SSLChannelOutputStream) getChannelOutputStream();
    }

    /**
     * Render an engine result on one line for logging.
     * @param result the result, may be null.
     * @return the single-line representation, or {@code "null"}.
     */
    private String resultToString(@Nullable SSLEngineResult result) {
        return result != null ? result.toString().replace('\n', ' ') : "null";
    }

    /**
     * Close the connection. A permit is then released so that a writer waiting for
     * handshake data wakes up and observes the closed socket.
     */
    @Override
    public void close() {
        super.close();
        this.logger.trace("Resuming for close");
        this.semaphore.release();
    }

    /**
     * Output stream that wraps plain text with the {@link SSLEngine} and writes the
     * encrypted result to the delegate channel stream. The client-side handshake is driven
     * from here when needed.
     */
    protected final class SSLChannelOutputStream
    extends TcpNioConnection.ChannelOutputStream {

        /** Serializes writers. */
        private final Lock lock = new ReentrantLock();

        private final TcpNioConnection.ChannelOutputStream channelOutputStream;

        SSLChannelOutputStream(TcpNioConnection.ChannelOutputStream channelOutputStream) {
            this.channelOutputStream = channelOutputStream;
        }

        /**
         * Wrap and write all remaining bytes of {@code plainText}. The handshake is performed
         * first if the engine requires it.
         * <p>
         * {@code writerActive} is set for the duration so that inbound processing wakes this
         * writer instead of wrapping itself.
         * @param plainText the application data to send.
         * @throws IOException on I/O or TLS errors.
         * @throws MessagingException if a wrap consumes no data while not handshaking.
         */
        @Override
        protected void doWrite(ByteBuffer plainText) throws IOException {
            this.lock.lock();
            try {
                TcpNioSSLConnection.this.writerActive = true;
                int remaining = plainText.remaining();
                while (remaining > 0) {
                    SSLEngineResult result = this.encode(plainText);
                    if (TcpNioSSLConnection.this.logger.isDebugEnabled()) {
                        TcpNioSSLConnection.this.logger.debug("doWrite: " + TcpNioSSLConnection.this.resultToString(result));
                    }
                    if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
                        this.writeEncodedIfAny();
                        // Guard against an infinite loop if the engine makes no progress.
                        if (plainText.remaining() >= remaining) {
                            throw new MessagingException("Unexpected condition - SSL wrap did not consume any data; remaining = " + remaining);
                        }
                        remaining = plainText.remaining();
                    }
                    else {
                        this.doClientSideHandshake(plainText, result);
                        this.writeEncodedIfAny();
                    }
                }
            }
            finally {
                TcpNioSSLConnection.this.writerActive = false;
                this.lock.unlock();
            }
        }

        /**
         * Drive the handshake from the writer side until the engine reports FINISHED or
         * NOT_HANDSHAKING.
         * <p>
         * When the engine needs to unwrap, this waits for the inbound side to signal via
         * the semaphore.
         * @param plainText the application data being written; passed to each wrap.
         * @param resultArg the most recent wrap result.
         * @throws IOException on I/O or TLS errors, including a fatal handshake failure.
         */
        private void doClientSideHandshake(ByteBuffer plainText, SSLEngineResult resultArg) throws IOException {
            SSLEngineResult result = resultArg;
            // Discard stale wake-ups from before this handshake attempt.
            TcpNioSSLConnection.this.semaphore.drainPermits();
            SSLEngineResult.HandshakeStatus status = TcpNioSSLConnection.this.sslEngine.getHandshakeStatus();
            while (status != SSLEngineResult.HandshakeStatus.FINISHED) {
                if (TcpNioSSLConnection.this.logger.isTraceEnabled()) {
                    TcpNioSSLConnection.this.logger.trace("Handshake Status: " + status);
                }
                this.writeEncodedIfAny();
                status = TcpNioSSLConnection.this.runTasksIfNeeded(result);
                if (status == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
                    status = this.waitForHandshakeData(result);
                }
                if (status == SSLEngineResult.HandshakeStatus.NEED_WRAP
                        || status == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING
                        || status == SSLEngineResult.HandshakeStatus.FINISHED) {
                    result = this.encode(plainText);
                    status = result.getHandshakeStatus();
                    if (status == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING
                            || status == SSLEngineResult.HandshakeStatus.FINISHED) {
                        break;
                    }
                }
                else {
                    TcpNioSSLConnection.this.logger.debug(status);
                }
            }
            if (TcpNioSSLConnection.this.logger.isTraceEnabled()) {
                TcpNioSSLConnection.this.logger.trace("Handshake Status: " + status);
            }
        }

        /**
         * Flip the shared {@code encoded} buffer, write it to the delegate stream, then
         * clear it.
         * <p>
         * The write is issued even when the buffer holds no bytes.
         * @throws IOException on I/O errors.
         */
        private void writeEncodedIfAny() throws IOException {
            TcpNioSSLConnection.this.encoded.flip();
            this.writeEncoded(TcpNioSSLConnection.this.encoded);
            TcpNioSSLConnection.this.encoded.clear();
        }

        /**
         * Block until the inbound side signals handshake progress, or until the
         * handshake timeout (in seconds) expires.
         * @param result the most recent engine result.
         * @return the handshake status after running any delegated tasks.
         * @throws SSLHandshakeException if the inbound side recorded a fatal handshake error.
         * @throws IOException if the connection was closed while waiting.
         * @throws MessagingException on timeout or interruption.
         */
        private SSLEngineResult.HandshakeStatus waitForHandshakeData(SSLEngineResult result) throws IOException {
            try {
                TcpNioSSLConnection.this.logger.trace("Writer waiting for handshake");
                if (!TcpNioSSLConnection.this.semaphore.tryAcquire(TcpNioSSLConnection.this.handshakeTimeout, TimeUnit.SECONDS)) {
                    throw new MessagingException("SSL Handshaking taking too long");
                }
                SSLHandshakeException fatal = TcpNioSSLConnection.this.sslFatal;
                if (fatal != null) {
                    throw fatal;
                }
                if (!TcpNioSSLConnection.this.isOpen()) {
                    throw new IOException("Socket closed during SSL Handshake");
                }
                TcpNioSSLConnection.this.logger.trace("Writer resuming handshake");
                return TcpNioSSLConnection.this.runTasksIfNeeded(result);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new MessagingException("Interrupted during SSL Handshaking", e);
            }
        }

        /**
         * Wrap {@code plainText} into the shared {@code encoded} buffer.
         * <p>
         * On BUFFER_OVERFLOW, the buffer is reallocated at the session's packet buffer size
         * and the wrap is retried once.
         * @param plainText the application data to wrap.
         * @return the wrap result.
         * @throws IOException on TLS errors.
         */
        private SSLEngineResult encode(ByteBuffer plainText) throws IOException {
            TcpNioSSLConnection.this.encoded.clear();
            SSLEngineResult result = TcpNioSSLConnection.this.sslEngine.wrap(plainText, TcpNioSSLConnection.this.encoded);
            if (TcpNioSSLConnection.this.logger.isDebugEnabled()) {
                TcpNioSSLConnection.this.logger.debug("After wrap: " + TcpNioSSLConnection.this.resultToString(result)
                        + " Plaintext buffer @" + plainText.position() + "/" + plainText.limit());
            }
            if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
                TcpNioSSLConnection.this.encoded = TcpNioSSLConnection.this.allocateEncryptionBuffer(
                        TcpNioSSLConnection.this.sslEngine.getSession().getPacketBufferSize());
                result = TcpNioSSLConnection.this.sslEngine.wrap(plainText, TcpNioSSLConnection.this.encoded);
            }
            return result;
        }

        /**
         * Write already-encrypted bytes directly to the delegate channel stream.
         * @param encoded the encrypted bytes, positioned for reading.
         * @throws IOException on I/O errors.
         */
        void writeEncoded(ByteBuffer encoded) throws IOException {
            this.channelOutputStream.doWrite(encoded);
        }

    }

}

