/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.jrt;

import com.yahoo.jrt.Buffer;
import com.yahoo.jrt.CryptoSocket;
import com.yahoo.jrt.TransportMetrics;
import com.yahoo.security.tls.ConnectionAuthContext;
import com.yahoo.security.tls.PeerAuthorizationFailedException;
import com.yahoo.security.tls.TransportSecurityUtils;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import java.util.Objects;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;

/**
 * A {@link CryptoSocket} that runs TLS on top of a {@link SocketChannel} using an {@link SSLEngine}.
 *
 * <p>TLS records produced by {@link SSLEngine#wrap} are staged in {@code wrapBuffer} until they are
 * written to the channel, and raw bytes read from the channel are staged in {@code unwrapBuffer}
 * until {@link SSLEngine#unwrap} consumes them. The handshake is driven step by step through
 * {@link #handshake()}, whose result tells the caller whether it must wait for the channel
 * (NEED_READ / NEED_WRITE) or run delegated engine tasks via {@link #doHandshakeWork()} (NEED_WORK)
 * before calling again. Application-data operations fail if the engine starts a new handshake
 * (renegotiation) after the initial one.
 */
public class TlsCryptoSocket implements CryptoSocket {
    /** Empty source buffer for handshake wraps, which must not consume any application data. */
    private static final ByteBuffer NULL_BUFFER = ByteBuffer.allocate(0);
    private static final Logger log = Logger.getLogger(TlsCryptoSocket.class.getName());

    private final TransportMetrics metrics = TransportMetrics.getInstance();
    private final SocketChannel channel;
    private final SSLEngine sslEngine;
    /** TLS records produced by wrap(), waiting to be written to the channel. */
    private final Buffer wrapBuffer;
    /** Raw bytes read from the channel (or injected), waiting to be unwrapped. */
    private final Buffer unwrapBuffer;
    private int sessionPacketBufferSize;
    private int sessionApplicationBufferSize;
    /** Scratch destination for handshake unwraps; set to null once the handshake completes. */
    private ByteBuffer handshakeDummyBuffer;
    private HandshakeState handshakeState;
    /** Authentication/authorization context of the peer; assigned when the handshake completes. */
    private ConnectionAuthContext authContext;

    /**
     * Creates a socket that has not started its handshake yet.
     *
     * @param channel   the underlying channel carrying TLS records
     * @param sslEngine the engine performing the TLS protocol (client or server mode is taken from it)
     */
    public TlsCryptoSocket(SocketChannel channel, SSLEngine sslEngine) {
        this.channel = channel;
        this.sslEngine = sslEngine;
        this.wrapBuffer = new Buffer(0);
        this.unwrapBuffer = new Buffer(0);
        // Sizes come from the session before any handshake has happened; they are refreshed
        // from the negotiated session in completeHandshake().
        SSLSession nullSession = sslEngine.getSession();
        this.sessionApplicationBufferSize = nullSession.getApplicationBufferSize();
        this.sessionPacketBufferSize = nullSession.getPacketBufferSize();
        this.handshakeDummyBuffer = ByteBuffer.allocate(this.sessionApplicationBufferSize);
        this.handshakeState = HandshakeState.NOT_STARTED;
        log.fine(() -> "Initialized with " + sslEngine.toString());
    }

    /**
     * Appends the readable bytes of {@code data} to the buffer of raw input awaiting unwrap.
     * NOTE(review): presumably used for bytes read from the channel before this socket took over
     * — confirm with callers.
     */
    public void injectReadData(Buffer data) {
        unwrapBuffer.getWritable(data.bytes()).put(data.getReadable());
    }

    @Override
    public SocketChannel channel() {
        return channel;
    }

    /**
     * Advances the handshake as far as possible without blocking on the channel.
     *
     * @return what the caller must do before calling again, or DONE once the handshake completed
     * @throws IOException on channel errors, a closed engine/channel, or handshake failure
     */
    @Override
    public CryptoSocket.HandshakeResult handshake() throws IOException {
        HandshakeState oldState = handshakeState;
        HandshakeState newState = processHandshakeState(oldState);
        log.fine(() -> String.format("Handshake state '%s -> %s'", oldState, newState));
        handshakeState = newState;
        return toHandshakeResult(newState);
    }

    /** Runs all delegated tasks the engine currently has pending, on the calling thread. */
    @Override
    public void doHandshakeWork() {
        for (Runnable task = sslEngine.getDelegatedTask(); task != null; task = sslEngine.getDelegatedTask()) {
            task.run();
        }
    }

    /**
     * Performs the channel I/O implied by the previous state, then loops on the engine's handshake
     * status until progress requires I/O, delegated work, or the handshake is done.
     */
    private HandshakeState processHandshakeState(HandshakeState state) throws IOException {
        try {
            switch (state) {
                case NOT_STARTED:
                    log.fine(() -> "Initiating handshake");
                    sslEngine.beginHandshake();
                    break;
                case NEED_WRITE:
                    channelWrite();
                    break;
                case NEED_READ:
                    channelRead();
                    break;
                case NEED_WORK:
                    // Delegated tasks have been run by doHandshakeWork(); just re-check the engine.
                    break;
                case COMPLETED:
                    return HandshakeState.COMPLETED;
                default:
                    throw unhandledStateException(state);
            }
            while (true) {
                SSLEngineResult.HandshakeStatus status = sslEngine.getHandshakeStatus();
                log.fine(() -> "SSLEngine.getHandshakeStatus(): " + status);
                switch (status) {
                    case NOT_HANDSHAKING:
                        // Records still pending in wrapBuffer must reach the peer before we are done.
                        if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE;
                        return completeHandshake();
                    case NEED_TASK:
                        return HandshakeState.NEED_WORK;
                    case NEED_UNWRAP:
                        // Flush our own records before waiting for the peer's response.
                        if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE;
                        if (!handshakeUnwrap()) return HandshakeState.NEED_READ;
                        break;
                    case NEED_WRAP:
                        if (!handshakeWrap()) return HandshakeState.NEED_WRITE;
                        break;
                    default:
                        throw new IllegalStateException("Unexpected handshake status: " + status);
                }
            }
        } catch (SSLHandshakeException e) {
            // Failures caused by peer authorization are not counted as certificate verification failures.
            if (!(e.getCause() instanceof PeerAuthorizationFailedException)) {
                metrics.incrementTlsCertificateVerificationFailures();
            }
            throw e;
        }
    }

    /**
     * Finalizes a successful handshake: disables creation of new sessions on the engine, releases
     * the handshake scratch buffer, adopts the negotiated session's buffer sizes, captures the
     * peer's auth context and updates connection metrics.
     */
    private HandshakeState completeHandshake() throws IOException {
        sslEngine.setEnableSessionCreation(false);
        handshakeDummyBuffer = null;
        SSLSession session = sslEngine.getSession();
        sessionApplicationBufferSize = session.getApplicationBufferSize();
        sessionPacketBufferSize = session.getPacketBufferSize();
        authContext = TransportSecurityUtils.getConnectionAuthContext(session).orElseThrow();
        if (!authContext.authorized()) {
            metrics.incrementPeerAuthorizationFailures();
        }
        log.fine(() -> String.format("Handshake complete: protocol=%s, cipherSuite=%s", session.getProtocol(), session.getCipherSuite()));
        if (sslEngine.getUseClientMode()) {
            metrics.incrementClientTlsConnectionsEstablished();
        } else {
            metrics.incrementServerTlsConnectionsEstablished();
        }
        return HandshakeState.COMPLETED;
    }

    /** Maps an internal handshake state to the public result; NOT_STARTED is never a valid result. */
    private static CryptoSocket.HandshakeResult toHandshakeResult(HandshakeState state) {
        switch (state) {
            case NEED_READ:
                return CryptoSocket.HandshakeResult.NEED_READ;
            case NEED_WRITE:
                return CryptoSocket.HandshakeResult.NEED_WRITE;
            case NEED_WORK:
                return CryptoSocket.HandshakeResult.NEED_WORK;
            case COMPLETED:
                return CryptoSocket.HandshakeResult.DONE;
            default:
                throw unhandledStateException(state);
        }
    }

    /**
     * Returns the session's application buffer size: taken from the pre-handshake session until
     * the handshake completes, and from the negotiated session afterwards.
     */
    @Override
    public int getMinimumReadBufferSize() {
        return sessionApplicationBufferSize;
    }

    /**
     * Delivers buffered application data if any; otherwise performs one channel read and unwraps
     * what it can.
     *
     * @return number of application bytes written to {@code dst} (possibly 0)
     * @throws ClosedChannelException if the channel reached end-of-stream or the engine is closed
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        verifyHandshakeCompleted();
        int bytesUnwrapped = drain(dst);
        if (bytesUnwrapped > 0) return bytesUnwrapped;
        if (channelRead() == 0) return 0;
        return drain(dst);
    }

    /**
     * Unwraps already-buffered input into {@code dst} until the input runs short (underflow) or
     * {@code dst} is full (overflow). Does not read from the channel.
     *
     * @return total number of application bytes written to {@code dst}
     */
    @Override
    public int drain(ByteBuffer dst) throws IOException {
        verifyHandshakeCompleted();
        int totalBytesUnwrapped = 0;
        for (int produced = applicationDataUnwrap(dst); produced >= 0; produced = applicationDataUnwrap(dst)) {
            // An OK result may produce 0 bytes (a record consumed without application data); keep going.
            totalBytesUnwrapped += produced;
        }
        return totalBytesUnwrapped;
    }

    /**
     * Wraps application data from {@code src} into pending TLS records. Pending records are
     * flushed first; if they cannot all be written, nothing is consumed from {@code src}. Wrapping
     * stops when no more progress is made or at least one packet's worth of records is pending.
     * The new records are not written here; call {@link #flush()}.
     *
     * @return number of bytes consumed from {@code src}
     */
    @Override
    public int write(ByteBuffer src) throws IOException {
        verifyHandshakeCompleted();
        if (flush() == CryptoSocket.FlushResult.NEED_WRITE) return 0;
        int totalBytesWrapped = 0;
        int bytesWrapped;
        do {
            bytesWrapped = applicationDataWrap(src);
            totalBytesWrapped += bytesWrapped;
        } while (bytesWrapped > 0 && wrapBuffer.bytes() < sessionPacketBufferSize);
        return totalBytesWrapped;
    }

    /** Writes pending TLS records to the channel; NEED_WRITE if some could not be written yet. */
    @Override
    public CryptoSocket.FlushResult flush() throws IOException {
        verifyHandshakeCompleted();
        channelWrite();
        return wrapBuffer.bytes() > 0 ? CryptoSocket.FlushResult.NEED_WRITE : CryptoSocket.FlushResult.DONE;
    }

    /** Calls {@code shrink(0)} on both staging buffers (see {@code Buffer.shrink} for its exact semantics). */
    @Override
    public void dropEmptyBuffers() {
        wrapBuffer.shrink(0);
        unwrapBuffer.shrink(0);
    }

    /**
     * @return the peer's auth context captured at handshake completion
     * @throws IllegalStateException if the handshake has not completed
     */
    @Override
    public ConnectionAuthContext connectionAuthContext() {
        if (handshakeState != HandshakeState.COMPLETED) {
            throw new IllegalStateException("Handshake not complete");
        }
        return Objects.requireNonNull(authContext);
    }

    /**
     * Produces one handshake wrap into {@code wrapBuffer}.
     *
     * @return true on progress; false on BUFFER_OVERFLOW, after refreshing the packet size, so the
     *         caller can write out pending records before retrying
     */
    private boolean handshakeWrap() throws IOException {
        SSLEngineResult result = sslEngineWrap(NULL_BUFFER);
        switch (result.getStatus()) {
            case OK:
                return true;
            case BUFFER_OVERFLOW:
                sessionPacketBufferSize = sslEngine.getSession().getPacketBufferSize();
                return false;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /** @return bytes consumed from {@code src}, or 0 on BUFFER_OVERFLOW */
    private int applicationDataWrap(ByteBuffer src) throws IOException {
        SSLEngineResult result = sslEngineWrap(src);
        failIfRenegotiationDetected(result);
        switch (result.getStatus()) {
            case OK:
                return result.bytesConsumed();
            case BUFFER_OVERFLOW:
                return 0;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /** Wraps {@code src} into room for one packet in {@code wrapBuffer}; fails if the engine is closed. */
    private SSLEngineResult sslEngineWrap(ByteBuffer src) throws IOException {
        SSLEngineResult result = sslEngine.wrap(src, wrapBuffer.getWritable(sessionPacketBufferSize));
        failIfCloseSignalDetected(result);
        return result;
    }

    /**
     * Unwraps one handshake record from {@code unwrapBuffer}.
     *
     * @return true if a record was processed; false if more input is needed (BUFFER_UNDERFLOW)
     * @throws SSLException if application data arrives before the handshake is complete
     */
    private boolean handshakeUnwrap() throws IOException {
        SSLEngineResult result = sslEngineUnwrap(handshakeDummyBuffer);
        switch (result.getStatus()) {
            case OK:
                if (result.bytesProduced() > 0) {
                    throw new SSLException("Got application data in handshake unwrap");
                }
                return true;
            case BUFFER_UNDERFLOW:
                return false;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /** @return application bytes produced into {@code dst} (>= 0), or -1 on BUFFER_OVERFLOW/BUFFER_UNDERFLOW */
    private int applicationDataUnwrap(ByteBuffer dst) throws IOException {
        SSLEngineResult result = sslEngineUnwrap(dst);
        failIfRenegotiationDetected(result);
        switch (result.getStatus()) {
            case OK:
                return result.bytesProduced();
            case BUFFER_OVERFLOW:
            case BUFFER_UNDERFLOW:
                return -1;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /** Unwraps buffered input into {@code dst}; fails if the engine is closed. */
    private SSLEngineResult sslEngineUnwrap(ByteBuffer dst) throws IOException {
        SSLEngineResult result = sslEngine.unwrap(unwrapBuffer.getReadable(), dst);
        failIfCloseSignalDetected(result);
        return result;
    }

    /**
     * Reads up to one packet's worth of bytes from the channel into {@code unwrapBuffer}.
     *
     * @throws ClosedChannelException on end-of-stream
     */
    private int channelRead() throws IOException {
        int read = channel.read(unwrapBuffer.getWritable(sessionPacketBufferSize));
        if (read == -1) {
            throw new ClosedChannelException();
        }
        return read;
    }

    /** Writes as much of {@code wrapBuffer} to the channel as it accepts. */
    private int channelWrite() throws IOException {
        return channel.write(wrapBuffer.getReadable());
    }

    private static void failIfCloseSignalDetected(SSLEngineResult result) throws ClosedChannelException {
        if (result.getStatus() == SSLEngineResult.Status.CLOSED) {
            throw new ClosedChannelException();
        }
    }

    /** After the initial handshake, any handshake activity signalled by the engine is rejected. */
    private static void failIfRenegotiationDetected(SSLEngineResult result) throws SSLException {
        SSLEngineResult.HandshakeStatus status = result.getHandshakeStatus();
        if (status != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING
                && status != SSLEngineResult.HandshakeStatus.FINISHED) {
            throw new SSLException("Renegotiation detected");
        }
    }

    private static IllegalStateException unhandledStateException(HandshakeState state) {
        return new IllegalStateException("Unhandled state: " + state);
    }

    private static IllegalStateException unexpectedStatusException(SSLEngineResult.Status status) {
        return new IllegalStateException("Unexpected status: " + status);
    }

    private void verifyHandshakeCompleted() throws SSLException {
        if (handshakeState != HandshakeState.COMPLETED) {
            throw new SSLException("Handshake not completed: handshakeState=" + handshakeState);
        }
    }

    /** Internal handshake progress; everything except NOT_STARTED maps to a public HandshakeResult. */
    private enum HandshakeState {
        NOT_STARTED,
        NEED_READ,
        NEED_WRITE,
        NEED_WORK,
        COMPLETED
    }
}

