/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.driver.internal.security;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLHandshakeException;
import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.BytePrinter;
import org.neo4j.driver.v1.Logger;
import org.neo4j.driver.v1.exceptions.ClientException;
import org.neo4j.driver.v1.exceptions.SecurityException;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;

/**
 * A {@link ByteChannel} that encrypts and decrypts traffic of an underlying {@link ByteChannel}
 * by driving an {@link SSLEngine} by hand.
 *
 * <p>Buffer conventions used throughout this class:
 * <ul>
 *   <li>{@code cipherIn} and {@code plainIn} are kept in "write mode" (ready for {@code put})
 *       between calls; they are flipped right before being drained and compacted afterwards.</li>
 *   <li>{@code cipherOut} is flushed to the underlying channel and cleared after every successful
 *       wrap, so it is empty between calls.</li>
 *   <li>{@code plainOut} never carries application data; it only serves as the source buffer for
 *       handshake and shutdown wraps.</li>
 * </ul>
 *
 * <p>The read loops assume the underlying channel makes progress on every call.
 * NOTE(review): this looks designed for a blocking channel only; confirm against callers.
 */
public class TLSSocketChannel implements ByteChannel {
    /** Zero-capacity destination used while handshaking, when no application data is wanted. */
    private static final ByteBuffer DUMMY_BUFFER = ByteBuffer.allocate(0);

    private final ByteChannel channel;
    private final Logger logger;
    private final BoltServerAddress address;
    private final SSLEngine sslEngine;
    /** Encrypted bytes waiting to be written to the underlying channel. */
    private ByteBuffer cipherOut;
    /** Encrypted bytes read from the underlying channel and not yet decrypted. */
    private ByteBuffer cipherIn;
    /** Decrypted bytes that did not fit into the caller's buffer yet. */
    private ByteBuffer plainIn;
    /** Source buffer for handshake/shutdown wraps; never holds application data. */
    private final ByteBuffer plainOut;

    /**
     * Creates a client-mode TLS channel on top of {@code channel} and runs the handshake.
     *
     * @param address      server address; its host and port are passed to the SSL engine
     * @param securityPlan provides the SSL context used to create the engine
     * @param channel      underlying transport
     * @param logger       receives debug/error output
     * @return a channel whose handshake has completed
     * @throws IOException       if the transport fails during the handshake
     * @throws SecurityException if the TLS handshake itself is rejected
     */
    public static TLSSocketChannel create(BoltServerAddress address, SecurityPlan securityPlan, ByteChannel channel, Logger logger) throws IOException {
        SSLEngine sslEngine = securityPlan.sslContext().createSSLEngine(address.host(), address.port());
        sslEngine.setUseClientMode(true);
        return TLSSocketChannel.create(channel, logger, sslEngine, address);
    }

    /**
     * Creates a TLS channel around an already configured engine and runs the handshake.
     *
     * @param channel   underlying transport
     * @param logger    receives debug/error output
     * @param sslEngine engine to drive; used as configured by the caller
     * @param address   server address, used in error messages
     * @return a channel whose handshake has completed
     * @throws IOException       if the transport fails during the handshake
     * @throws SecurityException if the handshake fails with an {@link SSLHandshakeException}
     */
    public static TLSSocketChannel create(ByteChannel channel, Logger logger, SSLEngine sslEngine, BoltServerAddress address) throws IOException {
        TLSSocketChannel tlsChannel = new TLSSocketChannel(channel, logger, sslEngine, address);
        try {
            tlsChannel.runHandshake();
        }
        catch (SSLHandshakeException e) {
            throw new SecurityException("Failed to establish secured connection with the server: " + e.getMessage(), e);
        }
        return tlsChannel;
    }

    /**
     * Allocates the four working buffers using the sizes advertised by the engine's session.
     * No handshake is performed here.
     */
    TLSSocketChannel(ByteChannel channel, Logger logger, SSLEngine sslEngine, BoltServerAddress address) throws IOException {
        this.address = address;
        this.logger = logger;
        this.channel = channel;
        this.sslEngine = sslEngine;
        int applicationBufferSize = sslEngine.getSession().getApplicationBufferSize();
        int packetBufferSize = sslEngine.getSession().getPacketBufferSize();
        this.plainIn = ByteBuffer.allocate(applicationBufferSize);
        this.cipherIn = ByteBuffer.allocate(packetBufferSize);
        this.plainOut = ByteBuffer.allocate(applicationBufferSize);
        this.cipherOut = ByteBuffer.allocate(packetBufferSize);
    }

    /**
     * Drives the engine until it reports that handshaking is over.
     *
     * @throws IOException     if reading from or writing to the transport fails
     * @throws ClientException if the engine reports a handshake status this class cannot serve
     */
    private void runHandshake() throws IOException {
        logger.debug("Running TLS handshake");
        sslEngine.beginHandshake();
        SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
        while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED
                && handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            switch (handshakeStatus) {
                case NEED_TASK: {
                    handshakeStatus = runDelegatedTasks();
                    break;
                }
                case NEED_UNWRAP: {
                    handshakeStatus = unwrap(DUMMY_BUFFER);
                    break;
                }
                case NEED_WRAP: {
                    handshakeStatus = wrap(plainOut);
                    break;
                }
                default: {
                    // Any other status would leave handshakeStatus unchanged and spin forever.
                    throw new ClientException("Got unexpected handshake status " + handshakeStatus + " while running TLS handshake.");
                }
            }
        }
        logger.debug("TLS handshake completed");
    }

    /**
     * Runs every task the engine has delegated, on the calling thread.
     *
     * @return the engine's handshake status after all tasks ran
     */
    private SSLEngineResult.HandshakeStatus runDelegatedTasks() {
        Runnable task = sslEngine.getDelegatedTask();
        while (task != null) {
            task.run();
            task = sslEngine.getDelegatedTask();
        }
        return sslEngine.getHandshakeStatus();
    }

    /**
     * Reads from the underlying channel.
     *
     * @param toBuffer destination for the raw bytes
     * @return number of bytes read, never negative
     * @throws IOException                 if the underlying read fails
     * @throws ServiceUnavailableException if the channel reports end-of-stream; the channel is
     *                                     closed (best effort) before this is thrown
     */
    int channelRead(ByteBuffer toBuffer) throws IOException {
        int read = channel.read(toBuffer);
        if (read < 0) {
            closeChannelQuietly();
            throw new ServiceUnavailableException("Failed to receive any data from the connected address " + address + ". Please ensure a working connection to the database.");
        }
        return read;
    }

    /**
     * Writes to the underlying channel.
     *
     * @param fromBuffer source of the raw bytes
     * @return number of bytes written, never negative
     * @throws IOException                 if the underlying write fails
     * @throws ServiceUnavailableException if the channel reports a negative count; the channel is
     *                                     closed (best effort) before this is thrown
     */
    int channelWrite(ByteBuffer fromBuffer) throws IOException {
        int written = channel.write(fromBuffer);
        if (written < 0) {
            closeChannelQuietly();
            throw new ServiceUnavailableException("Failed to send any data to the connected address " + address + ". Please ensure a working connection to the database.");
        }
        return written;
    }

    /** Closes the underlying channel, deliberately ignoring a failure to do so. */
    private void closeChannelQuietly() {
        try {
            channel.close();
        }
        catch (IOException ignored) {
            // Best effort only: the caller is about to report the broken connection, which is
            // the more useful error to surface.
        }
    }

    /**
     * Reads one batch of encrypted bytes from the underlying channel and decrypts as much of it as
     * possible. Decrypted bytes go to {@code buffer} as far as it has room; the rest stays in
     * {@code plainIn} for a later {@link #read(ByteBuffer)}.
     *
     * @param buffer destination for decrypted bytes; may have no room at all (handshake)
     * @return the handshake status to act on next
     * @throws IOException     if the transport or the engine fails
     * @throws ClientException if the engine reports an unknown result status
     */
    private SSLEngineResult.HandshakeStatus unwrap(ByteBuffer buffer) throws IOException {
        SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
        // This is the only place where data is pulled from the underlying channel.
        channelRead(cipherIn);
        cipherIn.flip();
        do {
            SSLEngineResult.Status status = sslEngine.unwrap(cipherIn, plainIn).getStatus();
            switch (status) {
                case OK: {
                    plainIn.flip();
                    TLSSocketChannel.bufferCopy(plainIn, buffer);
                    plainIn.compact();
                    handshakeStatus = runDelegatedTasks();
                    break;
                }
                case BUFFER_OVERFLOW: {
                    // plainIn is too small for the next record; grow it and retry.
                    enlargeApplicationInputBuffer();
                    break;
                }
                case BUFFER_UNDERFLOW: {
                    // Incomplete record: keep what we have (the helper leaves cipherIn in write
                    // mode) and let the next call read the remainder.
                    enlargeNetworkInputBuffer();
                    return handshakeStatus;
                }
                case CLOSED: {
                    sslEngine.closeInbound();
                    // Once inbound is closed the engine consumes nothing further, so bytes
                    // trailing the peer's closure alert are dropped; without this the loop
                    // condition below would stay true forever.
                    cipherIn.position(cipherIn.limit());
                    break;
                }
                default: {
                    throw new ClientException("Got unexpected status " + status + " while reading encrypted data.");
                }
            }
        } while (cipherIn.hasRemaining());
        cipherIn.compact();
        return handshakeStatus;
    }

    /**
     * Encrypts bytes from {@code buffer} and, when the engine produced output, writes all of it to
     * the underlying channel.
     *
     * @param buffer source of plain bytes; its position advances by whatever the engine consumed
     * @return the handshake status to act on next
     * @throws IOException                 if the transport or the engine fails
     * @throws ServiceUnavailableException if the engine reports that outbound is closed
     * @throws ClientException             if the engine reports an unknown result status
     */
    private SSLEngineResult.HandshakeStatus wrap(ByteBuffer buffer) throws IOException, ClientException {
        SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
        SSLEngineResult.Status status = sslEngine.wrap(buffer, cipherOut).getStatus();
        switch (status) {
            case OK: {
                handshakeStatus = runDelegatedTasks();
                flushCipherOut();
                break;
            }
            case BUFFER_OVERFLOW: {
                // Nothing was encrypted; make room and let the caller retry.
                enlargeNetworkOutBuffer();
                break;
            }
            case CLOSED: {
                sslEngine.closeOutbound();
                throw new ServiceUnavailableException("Encrypted connection closed while writing encrypted data.");
            }
            default: {
                throw new ClientException("Got unexpected status " + status + " while writing encrypted data.");
            }
        }
        return handshakeStatus;
    }

    /** Writes the whole content of {@code cipherOut} to the underlying channel and empties it. */
    private void flushCipherOut() throws IOException {
        cipherOut.flip();
        while (cipherOut.hasRemaining()) {
            channelWrite(cipherOut);
        }
        cipherOut.clear();
    }

    /**
     * Moves as many bytes as fit from {@code from} to {@code to}, advancing both positions.
     *
     * @return number of bytes transferred
     */
    private static int bufferCopy(ByteBuffer from, ByteBuffer to) {
        int maxTransfer = Math.min(to.remaining(), from.remaining());
        // Copy through a limited duplicate so that put() cannot overflow the destination.
        ByteBuffer temporaryBuffer = from.duplicate();
        temporaryBuffer.limit(temporaryBuffer.position() + maxTransfer);
        to.put(temporaryBuffer);
        from.position(from.position() + maxTransfer);
        return maxTransfer;
    }

    /**
     * Delivers decrypted bytes to {@code dst}. Bytes left over from an earlier decryption are
     * served first; only when there are none is the underlying channel read.
     *
     * @param dst destination buffer
     * @return number of bytes placed into {@code dst}; may be 0 when only part of a TLS record
     *         has arrived so far
     * @throws IOException if the transport or the engine fails
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        int toRead = dst.remaining();
        plainIn.flip();
        if (plainIn.hasRemaining()) {
            TLSSocketChannel.bufferCopy(plainIn, dst);
            plainIn.compact();
        } else {
            plainIn.clear();
            unwrap(dst);
        }
        return toRead - dst.remaining();
    }

    /**
     * Encrypts and sends everything remaining in {@code src}; returns only once {@code src} is
     * fully consumed.
     *
     * @param src plain bytes to send
     * @return number of bytes consumed from {@code src}, i.e. its initial {@code remaining()}
     * @throws IOException if the transport or the engine fails
     */
    @Override
    public int write(ByteBuffer src) throws IOException {
        int toWrite = src.remaining();
        while (src.hasRemaining()) {
            wrap(src);
        }
        return toWrite;
    }

    /** @return whether the underlying channel is open */
    @Override
    public boolean isOpen() {
        return channel.isOpen();
    }

    /**
     * Sends the TLS closure message and closes the underlying channel.
     *
     * <p>The underlying channel is closed even when the TLS shutdown fails, so the transport is
     * never leaked. An {@link IOException} from either step is logged, not propagated.
     */
    @Override
    public void close() throws IOException {
        try {
            try {
                shutdownOutbound();
            }
            finally {
                channel.close();
            }
            logger.debug("Closed secure channel");
        }
        catch (IOException e) {
            // The connection is gone either way; a failed clean shutdown is only worth a log entry.
            logger.error("TLS socket could not be closed cleanly", e);
        }
    }

    /** Tells the engine we are done sending and flushes the closure data it produces. */
    private void shutdownOutbound() throws IOException {
        plainOut.clear();
        sslEngine.closeOutbound();
        while (!sslEngine.isOutboundDone()) {
            sslEngine.wrap(plainOut, cipherOut);
            flushCipherOut();
        }
    }

    /**
     * Called on BUFFER_UNDERFLOW with {@code cipherIn} in read mode. Leaves {@code cipherIn} in
     * write mode with the unread bytes preserved, in a larger buffer if the session now asks for
     * a bigger packet size.
     */
    private void enlargeNetworkInputBuffer() {
        int curNetSize = cipherIn.capacity();
        int netSize = sslEngine.getSession().getPacketBufferSize();
        if (netSize > curNetSize) {
            ByteBuffer newCipherIn = ByteBuffer.allocate(netSize);
            newCipherIn.put(cipherIn);
            cipherIn = newCipherIn;
            logger.debug("Enlarged network input buffer from %s to %s. This operation should be a rare operation.", curNetSize, netSize);
        } else {
            // Big enough already; just make room for the rest of the record.
            cipherIn.compact();
        }
    }

    /**
     * Called on BUFFER_OVERFLOW with {@code plainIn} in write mode. Replaces it with a buffer that
     * holds the pending bytes plus one full application buffer.
     *
     * @throws ClientException if the new size would exceed twice the session's application
     *                         buffer size
     */
    private void enlargeApplicationInputBuffer() {
        plainIn.flip();
        int curAppSize = plainIn.capacity();
        int appSize = sslEngine.getSession().getApplicationBufferSize();
        int newAppSize = appSize + plainIn.remaining();
        if (newAppSize > appSize * 2) {
            // Message text is kept verbatim (including the "ro" typo) in case callers match on it.
            throw new ClientException(String.format("Failed ro enlarge application input buffer from %s to %s, as the maximum buffer size allowed is %s. The content in the buffer is: %s\n", curAppSize, newAppSize, appSize * 2, BytePrinter.hex(plainIn)));
        }
        ByteBuffer newPlainIn = ByteBuffer.allocate(newAppSize);
        newPlainIn.put(plainIn);
        plainIn = newPlainIn;
        logger.debug("Enlarged application input buffer from %s to %s. This operation should be a rare operation.", curAppSize, newAppSize);
    }

    /**
     * Called on BUFFER_OVERFLOW while wrapping. Either swaps in a larger {@code cipherOut} or, if
     * the session does not ask for more, pushes pending bytes to the channel to free space.
     *
     * <p>The replacement buffer starts empty; this relies on {@code cipherOut} having been
     * flushed and cleared after the previous wrap.
     */
    private void enlargeNetworkOutBuffer() throws IOException {
        int curNetSize = cipherOut.capacity();
        int netSize = sslEngine.getSession().getPacketBufferSize();
        if (netSize > curNetSize) {
            cipherOut = ByteBuffer.allocate(netSize);
            logger.debug("Enlarged network output buffer from %s to %s. This operation should be a rare operation.", curNetSize, netSize);
        } else {
            logger.debug("Network output buffer doesn't need enlarging, flushing data to the channel instead to open up space on the buffer.");
            cipherOut.flip();
            // Retry until a write makes progress; one successful write is enough to free space.
            while (cipherOut.hasRemaining() && channelWrite(cipherOut) <= 0) {
                logger.debug("having difficulty flushing data (network contention on local computer?). will continue trying after yielding execution.");
                Thread.yield();
                logger.debug("nothing written to the underlying channel (network output buffer is full?), will try till we can.");
            }
            cipherOut.compact();
        }
    }

    /** @return a debug representation showing the state of the two input buffers */
    @Override
    public String toString() {
        return "TLSSocketChannel{plainIn: " + plainIn + ", cipherIn:" + cipherIn + "}";
    }
}

