package org.jgroups.protocols;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.security.KeyStore;
import java.util.Map;
import java.util.Objects;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.Global;
import org.jgroups.View;
import org.jgroups.annotations.LocalAddress;
import org.jgroups.annotations.MBean;
import org.jgroups.annotations.Property;
import org.jgroups.stack.IpAddress;
import org.jgroups.util.Runner;
import org.jgroups.util.Tuple;
import org.jgroups.util.Util;

@MBean(description = "Key exchange protocol based on an SSL connection between secret key requester and provider (key server) to fetch a shared secret group key from the key server. That shared (symmetric) key is subsequently used to encrypt communication between cluster members")
public class SSL_KEY_EXCHANGE extends KeyExchange {

    /**
     * Upper bound accepted for the length prefix of the version and secret-key fields of a key response.
     * Guards against allocating absurd buffers when the peer sends a corrupt or hostile length.
     */
    private static final int MAX_FIELD_LENGTH = 64 * 1024;

    @Property(description = "Bind address for the server or client socket. The following special values are also recognized: GLOBAL, SITE_LOCAL, LINK_LOCAL and NON_LOOPBACK", systemProperty = {Global.BIND_ADDR})
    @LocalAddress
    protected InetAddress bind_addr;

    @Property(description = "The port at which the key server is listening. If the port is not available, the next port will be probed, up to port+port_range. Used by the key server (server) to create an SSLServerSocket and by clients to connect to the key server.")
    protected int port = 2157;

    @Property(description = "The port range to probe")
    protected int port_range = 5;

    @Property(description = "Location of the keystore")
    protected String keystore_name = "keystore.jks";

    @Property(description = "The type of the keystore. Types are listed in http://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html")
    protected String keystore_type = "JKS";

    @Property(description = "Password to access the keystore", exposeAsManagedAttribute = false)
    protected String keystore_password = "changeit";

    @Property(description = "The type of secret key to be sent up the stack (converted from DH). Should be the same as the algorithm part of ASYM_ENCRYPT.sym_algorithm if ASYM_ENCRYPT is used")
    protected String secret_key_algorithm = "AES";

    @Property(description = "If enabled, clients are authenticated as well (not just the server). Set to true to prevent rogue clients to fetch the secret group key (e.g. via man-in-the-middle attacks)")
    protected boolean require_client_authentication = true;

    @Property(description = "Timeout (in ms) for a socket read. This applies for example to the initial SSL handshake, e.g. if the client connects to a non-JGroups service accidentally running on the same port")
    protected int socket_timeout = 1000;

    @Property(description = "The fully qualified name of a class implementing SessionVerifier")
    protected String session_verifier_class;

    @Property(description = "The argument to the session verifier")
    protected String session_verifier_arg;

    /** Client-side SSL context; lazily built from {@link #key_store} by {@link #getContext()} when not set. */
    protected SSLContext client_ssl_ctx;

    /** Server-side SSL context; when null, the key server falls back to {@link #getContext()}. */
    protected SSLContext server_ssl_ctx;

    /** Key server socket; created by {@link #becomeKeyserver()} and cleared by {@link #stopKeyserver()}. */
    protected SSLServerSocket srv_sock;

    /** Runner driving {@link #accept()} on {@link #srv_sock} while this member is the key server. */
    protected Runner srv_sock_handler;

    /** Key and trust material; loaded in {@link #init()} unless both SSL contexts were supplied. */
    protected KeyStore key_store;

    /** Most recently installed view; used to detect coordinator changes. */
    protected View view;

    /** Optional extra check applied to every established SSL session, on both client and server side. */
    protected SessionVerifier session_verifier;

    /** Hook for additional verification of an established SSL session. */
    public interface SessionVerifier {
        /** Called once after instantiation from {@code session_verifier_class}, only if {@code session_verifier_arg} is set. */
        void init(String arg);

        /** Verifies the session; throwing a {@link SecurityException} rejects the peer. */
        void verify(SSLSession session) throws SecurityException;
    }

    /** Type byte that starts each message on a key-exchange connection; sent on the wire as the ordinal. */
    protected enum Type {
        SECRET_KEY_REQ,
        SECRET_KEY_RSP
    }

    public InetAddress getBindAddress() {
        return bind_addr;
    }

    public SSL_KEY_EXCHANGE setBindAddress(InetAddress addr) {
        bind_addr = addr;
        return this;
    }

    public int getPort() {
        return port;
    }

    public SSL_KEY_EXCHANGE setPort(int port) {
        this.port = port;
        return this;
    }

    public int getPortRange() {
        return port_range;
    }

    public SSL_KEY_EXCHANGE setPortRange(int range) {
        port_range = range;
        return this;
    }

    public String getKeystoreName() {
        return keystore_name;
    }

    public SSL_KEY_EXCHANGE setKeystoreName(String name) {
        keystore_name = name;
        return this;
    }

    public String getKeystoreType() {
        return keystore_type;
    }

    public SSL_KEY_EXCHANGE setKeystoreType(String type) {
        keystore_type = type;
        return this;
    }

    public String getKeystorePassword() {
        return keystore_password;
    }

    public SSL_KEY_EXCHANGE setKeystorePassword(String password) {
        keystore_password = password;
        return this;
    }

    public String getSecretKeyAlgorithm() {
        return secret_key_algorithm;
    }

    public SSL_KEY_EXCHANGE setSecretKeyAlgorithm(String algorithm) {
        secret_key_algorithm = algorithm;
        return this;
    }

    public boolean getRequireClientAuthentication() {
        return require_client_authentication;
    }

    public SSL_KEY_EXCHANGE setRequireClientAuthentication(boolean require) {
        require_client_authentication = require;
        return this;
    }

    public int getSocketTimeout() {
        return socket_timeout;
    }

    public SSL_KEY_EXCHANGE setSocketTimeout(int timeout) {
        socket_timeout = timeout;
        return this;
    }

    public String getSessionVerifierClass() {
        return session_verifier_class;
    }

    public SSL_KEY_EXCHANGE setSessionVerifierClass(String cl) {
        session_verifier_class = cl;
        return this;
    }

    public String getSessionVerifierArg() {
        return session_verifier_arg;
    }

    public SSL_KEY_EXCHANGE setSessionVerifierArg(String arg) {
        session_verifier_arg = arg;
        return this;
    }

    public KeyStore getKeystore() {
        return key_store;
    }

    public SSL_KEY_EXCHANGE setKeystore(KeyStore ks) {
        key_store = ks;
        return this;
    }

    public SessionVerifier getSessionVerifier() {
        return session_verifier;
    }

    public SSL_KEY_EXCHANGE setSessionVerifier(SessionVerifier verifier) {
        session_verifier = verifier;
        return this;
    }

    /** @deprecated use {@link #getClientSSLContext()} instead. */
    @Deprecated
    public SSLContext getSSLContext() {
        return client_ssl_ctx;
    }

    /** @deprecated use {@link #setClientSSLContext(SSLContext)} instead. */
    @Deprecated
    public SSL_KEY_EXCHANGE setSSLContext(SSLContext ctx) {
        client_ssl_ctx = ctx;
        return this;
    }

    public SSLContext getClientSSLContext() {
        return client_ssl_ctx;
    }

    public SSL_KEY_EXCHANGE setClientSSLContext(SSLContext ctx) {
        client_ssl_ctx = ctx;
        return this;
    }

    public SSLContext getServerSSLContext() {
        return server_ssl_ctx;
    }

    public SSL_KEY_EXCHANGE setServerSSLContext(SSLContext ctx) {
        server_ssl_ctx = ctx;
        return this;
    }

    /**
     * Returns the address clients should connect to for fetching the secret key.
     *
     * @return the transport's bind address plus the key server's local port, or null when this member
     *         is not currently running a key server
     */
    @Override
    public Address getServerLocation() {
        // Read the field once: stopKeyserver() may null it concurrently
        SSLServerSocket sock = srv_sock;
        return sock == null ? null : new IpAddress(getTransport().getBindAddress(), sock.getLocalPort());
    }

    /**
     * Validates the configuration, aligns the secret key algorithm with ASYM_ENCRYPT (if present), loads
     * the keystore when an SSL context still has to be built from it, and instantiates the session verifier.
     *
     * @throws IllegalStateException if {@code port} is 0
     * @throws FileNotFoundException if the keystore can be found neither on disk nor as a resource
     */
    @Override
    public void init() throws Exception {
        super.init();
        if (port == 0) {
            throw new IllegalStateException("port must not be 0 or else clients will not be able to connect");
        }
        ASYM_ENCRYPT asym_encrypt = (ASYM_ENCRYPT) findProtocolAbove(ASYM_ENCRYPT.class);
        if (asym_encrypt != null) {
            String sym_alg = asym_encrypt.symKeyAlgorithm();
            if (!Util.match(sym_alg, secret_key_algorithm)) {
                log.warn("%s: overriding %s=%s to %s from %s", local_addr, "secret_key_algorithm",
                         secret_key_algorithm, sym_alg, ASYM_ENCRYPT.class.getSimpleName());
                secret_key_algorithm = sym_alg;
            }
        }
        if (key_store == null && (client_ssl_ctx == null || server_ssl_ctx == null)) {
            key_store = KeyStore.getInstance(keystore_type != null ? keystore_type : KeyStore.getDefaultType());
            InputStream input;
            try {
                input = new FileInputStream(keystore_name);
            } catch (FileNotFoundException not_on_disk) {
                // Not a file path: fall back to loading keystore_name as a resource
                input = Util.getResourceAsStream(keystore_name, getClass());
            }
            if (input == null) {
                throw new FileNotFoundException(keystore_name);
            }
            // Close the stream once loaded (the original leaked it)
            try (InputStream in = input) {
                key_store.load(in, password());
            }
        }
        if (session_verifier == null && session_verifier_class != null) {
            session_verifier = (SessionVerifier) Util.loadClass(session_verifier_class, getClass())
              .getDeclaredConstructor().newInstance();
            if (session_verifier_arg != null) {
                session_verifier.init(session_verifier_arg);
            }
        }
    }

    @Override
    public void start() throws Exception {
        super.start();
    }

    @Override
    public void stop() {
        super.stop();
        stopKeyserver();
    }

    @Override
    public void destroy() {
        super.destroy();
    }

    /**
     * Picks up {@code bind_addr} from a CONFIG event (unless configured explicitly) and passes every
     * event on up the stack.
     */
    @Override
    public Object up(Event event) {
        if (event.getType() == Event.CONFIG) {
            if (bind_addr == null) {
                Map<?, ?> config = (Map<?, ?>) event.getArg();
                bind_addr = (InetAddress) config.get("bind_addr");
            }
            return up_prot.up(event);
        }
        return super.up(event);
    }

    /**
     * Connects to the key server at {@code address} and requests the secret group key. On success, sends
     * the key and its version up the stack as a SET_SECRET_KEY event.
     *
     * @throws IllegalStateException if the connection fails or the server sends a malformed response
     * @throws SecurityException if the session verifier rejects the server's session
     */
    @Override
    public void fetchSecretKeyFrom(Address address) throws Exception {
        try (SSLSocket sock = createSocketTo(address)) {
            DataInputStream in = new DataInputStream(sock.getInputStream());
            OutputStream out = sock.getOutputStream();
            out.write(Type.SECRET_KEY_REQ.ordinal());
            out.flush();

            byte type = in.readByte();
            if (toType(type) != Type.SECRET_KEY_RSP) {
                throw new IllegalStateException(String.format("expected response of %s but got type=%d", Type.SECRET_KEY_RSP, type));
            }
            // Wire format: [int len][version bytes][int len][encoded secret key bytes]
            byte[] version = readLengthPrefixed(in, "version");
            byte[] secret_key = readLengthPrefixed(in, "secret key");
            Tuple<SecretKey, byte[]> tuple = new Tuple<>(new SecretKeySpec(secret_key, secret_key_algorithm), version);
            log.debug("%s: sending up secret key (version: %s)", local_addr, Util.byteArrayToHexString(version));
            up_prot.up(new Event(Event.SET_SECRET_KEY, tuple));
        }
    }

    /**
     * Serves one client of the key server.
     *
     * Accepts a connection and completes the SSL handshake, then applies the optional session verifier.
     * If the client sent a SECRET_KEY_REQ, it answers with the current secret key obtained via a
     * GET_SECRET_KEY event. Failures are logged at trace level and never propagate.
     */
    protected void accept() {
        try (SSLSocket client_sock = (SSLSocket) srv_sock.accept()) {
            // Bound blocking reads (including the handshake) so a client that connects but never speaks
            // cannot stall the key-server thread indefinitely
            client_sock.setSoTimeout(socket_timeout);
            // Cipher suites are left at the JSSE defaults: enabling *all* supported suites would also enable
            // anonymous and NULL suites, which skip authentication / encryption of the group key
            client_sock.startHandshake();
            SSLSession session = client_sock.getSession();
            log.debug("%s: accepted SSL connection from %s; protocol: %s, cipher suite: %s",
                      local_addr, client_sock.getRemoteSocketAddress(), session.getProtocol(), session.getCipherSuite());
            if (session_verifier != null) {
                session_verifier.verify(session);
            }

            InputStream in = client_sock.getInputStream();
            DataOutputStream out = new DataOutputStream(client_sock.getOutputStream());

            int type = in.read(); // -1 on EOF, rejected by toType()
            if (toType(type) != Type.SECRET_KEY_REQ) {
                throw new IllegalStateException(String.format("expected request of %s but got type=%d", Type.SECRET_KEY_REQ, type));
            }

            @SuppressWarnings("unchecked") // GET_SECRET_KEY answers with (secret key, version)
            Tuple<SecretKey, byte[]> tuple = (Tuple<SecretKey, byte[]>) up_prot.up(new Event(Event.GET_SECRET_KEY));
            if (tuple == null) {
                return;
            }
            byte[] version = tuple.getVal2();
            byte[] secret_key = tuple.getVal1().getEncoded();
            out.write(Type.SECRET_KEY_RSP.ordinal());
            out.writeInt(version.length);
            out.write(version, 0, version.length);
            out.writeInt(secret_key.length);
            out.write(secret_key);
            out.flush();
        } catch (Throwable t) {
            log.trace("%s: failure handling client socket: %s", local_addr, t.getMessage());
        }
    }

    /**
     * Starts the key server when this member becomes coordinator, and stops it when this member stops
     * being coordinator.
     */
    @Override
    protected void handleView(View view) {
        Address old_coord = this.view != null ? this.view.getCoord() : null;
        this.view = view;
        boolean was_coord = Objects.equals(old_coord, local_addr);
        boolean is_coord = Objects.equals(view.getCoord(), local_addr);
        if (is_coord && !was_coord) {
            try {
                becomeKeyserver();
            } catch (Throwable t) {
                log.error("failed becoming key server", t);
            }
        } else if (!is_coord && was_coord) {
            stopKeyserver();
        }
    }

    /** Creates the server socket and starts the accept runner, unless an open server socket already exists. */
    protected synchronized void becomeKeyserver() throws Exception {
        if (srv_sock == null || srv_sock.isClosed()) {
            log.debug("%s: becoming keyserver; creating server socket", local_addr);
            srv_sock = createServerSocket();
            srv_sock_handler = new Runner(getThreadFactory(), SSL_KEY_EXCHANGE.class.getSimpleName() + "-runner",
                                          this::accept, () -> Util.close(srv_sock));
            srv_sock_handler.start();
            log.debug("%s: SSL server socket listening on %s", local_addr, srv_sock.getLocalSocketAddress());
        }
    }

    /** Closes the server socket and stops the accept runner; a no-op if neither exists. */
    protected synchronized void stopKeyserver() {
        if (srv_sock != null) {
            Util.close(srv_sock);
            srv_sock = null;
        }
        if (srv_sock_handler != null) {
            log.debug("%s: ceasing to be the keyserver; closing the server socket", local_addr);
            srv_sock_handler.stop();
            srv_sock_handler = null;
        }
    }

    /**
     * Binds an SSL server socket to the first free port in {@code [port, port+port_range]} on {@code bind_addr}.
     *
     * @throws IllegalStateException if no port in the range could be bound (the last failure is the cause)
     */
    protected SSLServerSocket createServerSocket() throws Exception {
        SSLContext ctx = server_ssl_ctx != null ? server_ssl_ctx : getContext();
        SSLServerSocketFactory factory = ctx.getServerSocketFactory();
        Throwable last_failure = null;
        for (int i = 0; i <= port_range; i++) {
            try {
                SSLServerSocket sock = (SSLServerSocket) factory.createServerSocket(port + i, 50, bind_addr);
                sock.setNeedClientAuth(require_client_authentication);
                return sock;
            } catch (Throwable t) {
                last_failure = t; // e.g. port already in use: probe the next one
            }
        }
        throw new IllegalStateException(String.format("found no valid port to bind to in range [%d-%d]", port, port + port_range), last_failure);
    }

    /**
     * Opens a handshaken (and verified) SSL connection to the key server.
     *
     * If {@code address} is an {@link IpAddress}, it is connected to directly. Otherwise its physical
     * address is looked up down the stack and ports {@code [port, port+port_range]} are tried in order.
     *
     * @throws SecurityException if the session verifier rejects the session (no further ports are tried)
     * @throws IllegalStateException if no connection could be established
     */
    protected SSLSocket createSocketTo(Address address) throws Exception {
        SSLSocketFactory factory = (client_ssl_ctx != null ? client_ssl_ctx : getContext()).getSocketFactory();
        if (address instanceof IpAddress) {
            return createSocketTo((IpAddress) address, factory);
        }
        IpAddress dest = (IpAddress) down_prot.down(new Event(Event.GET_PHYSICAL_ADDRESS, address));
        if (dest == null) {
            throw new IllegalStateException(String.format("failed fetching physical address of %s", address));
        }
        Throwable last_failure = null;
        for (int i = 0; i <= port_range; i++) {
            try {
                return connect(factory, dest.getIpAddress(), port + i, address);
            } catch (SecurityException e) {
                throw e; // a rejected session is fatal: don't try the remaining ports
            } catch (Throwable t) {
                last_failure = t;
            }
        }
        throw new IllegalStateException(String.format("failed connecting to %s (port range [%d - %d])",
                                                      dest.getIpAddress(), port, port + port_range), last_failure);
    }

    /**
     * Opens a handshaken (and verified) SSL connection to {@code ipAddress}.
     *
     * @throws SecurityException if the session verifier rejects the session
     * @throws IllegalStateException on any other failure (with the original failure as cause)
     */
    protected SSLSocket createSocketTo(IpAddress ipAddress, SSLSocketFactory sSLSocketFactory) {
        try {
            return connect(sSLSocketFactory, ipAddress.getIpAddress(), ipAddress.getPort(), ipAddress);
        } catch (SecurityException e) {
            throw e;
        } catch (Throwable t) {
            throw new IllegalStateException(String.format("failed connecting to %s: %s", ipAddress, t.getMessage()), t);
        }
    }

    /**
     * Returns the client SSL context, building it from {@link #key_store} on first use.
     * The built context is cached in {@link #client_ssl_ctx}.
     */
    protected SSLContext getContext() throws Exception {
        if (client_ssl_ctx != null) {
            return client_ssl_ctx;
        }
        KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        kmf.init(key_store, password());
        KeyManager[] key_managers = kmf.getKeyManagers();
        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(key_store);
        TrustManager[] trust_managers = tmf.getTrustManagers();
        // "TLS" negotiates the highest mutually enabled version. The former "TLSv1" pinned clients to
        // TLS 1.0, which current JDKs list in jdk.tls.disabledAlgorithms.
        SSLContext ctx = SSLContext.getInstance("TLS");
        ctx.init(key_managers, trust_managers, null);
        client_ssl_ctx = ctx;
        return ctx;
    }

    /**
     * Creates a socket to {@code host:dest_port}, applies the read timeout, and performs the handshake.
     * It then runs the optional session verifier. The socket is closed if any of these steps fails.
     *
     * @param target what to report as the destination in the debug log
     */
    private SSLSocket connect(SSLSocketFactory factory, InetAddress host, int dest_port, Object target) throws Exception {
        SSLSocket sock = (SSLSocket) factory.createSocket(host, dest_port);
        try {
            sock.setSoTimeout(socket_timeout);
            // Cipher suites are left at the JSSE defaults (see accept())
            sock.startHandshake();
            SSLSession session = sock.getSession();
            log.debug("%s: created SSL connection to %s (%s); protocol: %s, cipher suite: %s",
                      local_addr, target, sock.getRemoteSocketAddress(), session.getProtocol(), session.getCipherSuite());
            if (session_verifier != null) {
                session_verifier.verify(session);
            }
            return sock;
        } catch (Throwable t) {
            Util.close(sock);
            throw t;
        }
    }

    /** Maps a wire type byte back to its {@link Type}; returns null for values outside the enum's range. */
    private static Type toType(int ordinal) {
        Type[] types = Type.values();
        return ordinal >= 0 && ordinal < types.length ? types[ordinal] : null;
    }

    /**
     * Reads an int length prefix followed by that many bytes.
     *
     * @throws IllegalStateException if the length is negative or exceeds {@link #MAX_FIELD_LENGTH}
     */
    private static byte[] readLengthPrefixed(DataInputStream in, String what) throws java.io.IOException {
        int len = in.readInt();
        if (len < 0 || len > MAX_FIELD_LENGTH) {
            throw new IllegalStateException(String.format("invalid %s length: %d (max: %d)", what, len, MAX_FIELD_LENGTH));
        }
        byte[] buf = new byte[len];
        in.readFully(buf);
        return buf;
    }

    /** Returns the keystore password as a char array, or null if none is configured. */
    private char[] password() {
        return keystore_password != null ? keystore_password.toCharArray() : null;
    }
}
