/*
 * Decompiled with CFR 0.152.
 */
package org.jgroups.protocols;

import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.security.KeyStore;
import java.util.Map;
import java.util.Objects;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.View;
import org.jgroups.annotations.LocalAddress;
import org.jgroups.annotations.MBean;
import org.jgroups.annotations.Property;
import org.jgroups.protocols.ASYM_ENCRYPT;
import org.jgroups.protocols.KeyExchange;
import org.jgroups.stack.IpAddress;
import org.jgroups.util.Runner;
import org.jgroups.util.Tuple;
import org.jgroups.util.Util;

@MBean(description="Key exchange protocol based on an SSL connection between secret key requester and provider (key server) to fetch a shared secret group key from the key server. That shared (symmetric) key is subsequently used to encrypt communication between cluster members")
/**
 * Key exchange over an SSL/TLS connection.
 * <p>
 * The current view coordinator becomes the key server: it opens an {@link SSLServerSocket} on the first free port in
 * {@code [port .. port+port_range]} and answers {@link Type#SECRET_KEY_REQ} requests with a {@link Type#SECRET_KEY_RSP}
 * containing the version and encoded bytes of the secret key it obtains from the protocol above it.
 * Other members connect to the key server, send a request, and pass the received key and version up the stack
 * as a {@link Tuple}.
 * <p>
 * Wire format of a response: 1 byte type ordinal, int version length, version bytes, int key length, key bytes.
 */
public class SSL_KEY_EXCHANGE
extends KeyExchange {
    @LocalAddress
    @Property(description="Bind address for the server or client socket. The following special values are also recognized: GLOBAL, SITE_LOCAL, LINK_LOCAL and NON_LOOPBACK", systemProperty={"jgroups.bind_addr"})
    protected InetAddress bind_addr;
    @Property(description="The port at which the key server is listening. If the port is not available, the next port will be probed, up to port+port_range. Used by the key server (server) to create an SSLServerSocket and by clients to connect to the key server.")
    protected int port = 2157;
    @Property(description="The port range to probe")
    protected int port_range = 5;
    @Property(description="Location of the keystore")
    protected String keystore_name = "keystore.jks";
    @Property(description="The type of the keystore. Types are listed in http://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html")
    protected String keystore_type = "JKS";
    @Property(description="Password to access the keystore", exposeAsManagedAttribute=false)
    protected String keystore_password = "changeit";
    @Property(description="The type of secret key to be sent up the stack (converted from DH). Should be the same as the algorithm part of ASYM_ENCRYPT.sym_algorithm if ASYM_ENCRYPT is used")
    protected String secret_key_algorithm = "AES";
    @Property(description="If enabled, clients are authenticated as well (not just the server). Set to true to prevent rogue clients to fetch the secret group key (e.g. via man-in-the-middle attacks)")
    protected boolean require_client_authentication = true;
    @Property(description="Timeout (in ms) for a socket read. This applies for example to the initial SSL handshake, e.g. if the client connects to a non-JGroups service accidentally running on the same port")
    protected int socket_timeout = 1000;
    @Property(description="The fully qualified name of a class implementing SessionVerifier")
    protected String session_verifier_class;
    @Property(description="The argument to the session verifier")
    protected String session_verifier_arg;
    /** Context used for outgoing connections; also the fallback for the server side (see {@link #getContext()}). */
    protected SSLContext client_ssl_ctx;
    /** Optional dedicated context for the key server's server socket. */
    protected SSLContext server_ssl_ctx;
    /** Server socket; non-null only while this member acts as key server. */
    protected SSLServerSocket srv_sock;
    /** Runs {@link #accept()} for {@link #srv_sock}; non-null only while this member acts as key server. */
    protected Runner srv_sock_handler;
    /** Keystore used as both key material and trust material when an SSLContext has to be built. */
    protected KeyStore key_store;
    /** Most recent view passed to {@link #handleView(View)}. */
    protected View view;
    /** Optional extra check applied to every established SSL session (client and server side). */
    protected SessionVerifier session_verifier;

    public InetAddress getBindAddress() {
        return this.bind_addr;
    }

    public SSL_KEY_EXCHANGE setBindAddress(InetAddress a) {
        this.bind_addr = a;
        return this;
    }

    public int getPort() {
        return this.port;
    }

    public SSL_KEY_EXCHANGE setPort(int p) {
        this.port = p;
        return this;
    }

    public int getPortRange() {
        return this.port_range;
    }

    public SSL_KEY_EXCHANGE setPortRange(int r) {
        this.port_range = r;
        return this;
    }

    public String getKeystoreName() {
        return this.keystore_name;
    }

    public SSL_KEY_EXCHANGE setKeystoreName(String name) {
        this.keystore_name = name;
        return this;
    }

    public String getKeystoreType() {
        return this.keystore_type;
    }

    public SSL_KEY_EXCHANGE setKeystoreType(String type) {
        this.keystore_type = type;
        return this;
    }

    public String getKeystorePassword() {
        return this.keystore_password;
    }

    public SSL_KEY_EXCHANGE setKeystorePassword(String pwd2) {
        this.keystore_password = pwd2;
        return this;
    }

    public String getSecretKeyAlgorithm() {
        return this.secret_key_algorithm;
    }

    public SSL_KEY_EXCHANGE setSecretKeyAlgorithm(String a) {
        this.secret_key_algorithm = a;
        return this;
    }

    public boolean getRequireClientAuthentication() {
        return this.require_client_authentication;
    }

    public SSL_KEY_EXCHANGE setRequireClientAuthentication(boolean b) {
        this.require_client_authentication = b;
        return this;
    }

    public int getSocketTimeout() {
        return this.socket_timeout;
    }

    public SSL_KEY_EXCHANGE setSocketTimeout(int timeout) {
        this.socket_timeout = timeout;
        return this;
    }

    public String getSessionVerifierClass() {
        return this.session_verifier_class;
    }

    public SSL_KEY_EXCHANGE setSessionVerifierClass(String cl) {
        this.session_verifier_class = cl;
        return this;
    }

    public String getSessionVerifierArg() {
        return this.session_verifier_arg;
    }

    public SSL_KEY_EXCHANGE setSessionVerifierArg(String arg) {
        this.session_verifier_arg = arg;
        return this;
    }

    public KeyStore getKeystore() {
        return this.key_store;
    }

    public SSL_KEY_EXCHANGE setKeystore(KeyStore ks) {
        this.key_store = ks;
        return this;
    }

    public SessionVerifier getSessionVerifier() {
        return this.session_verifier;
    }

    public SSL_KEY_EXCHANGE setSessionVerifier(SessionVerifier s) {
        this.session_verifier = s;
        return this;
    }

    /**
     * @return the client-side SSL context
     * @deprecated use {@link #getClientSSLContext()} instead
     */
    @Deprecated
    public SSLContext getSSLContext() {
        return this.client_ssl_ctx;
    }

    /**
     * Sets the client-side SSL context.
     * @deprecated use {@link #setClientSSLContext(SSLContext)} (and optionally {@link #setServerSSLContext(SSLContext)})
     */
    @Deprecated
    public SSL_KEY_EXCHANGE setSSLContext(SSLContext ssl_ctx) {
        this.client_ssl_ctx = ssl_ctx;
        return this;
    }

    public SSLContext getClientSSLContext() {
        return this.client_ssl_ctx;
    }

    public SSL_KEY_EXCHANGE setClientSSLContext(SSLContext client_ssl_ctx) {
        this.client_ssl_ctx = client_ssl_ctx;
        return this;
    }

    public SSLContext getServerSSLContext() {
        return this.server_ssl_ctx;
    }

    public SSL_KEY_EXCHANGE setServerSSLContext(SSLContext server_ssl_ctx) {
        this.server_ssl_ctx = server_ssl_ctx;
        return this;
    }

    /**
     * Returns the address clients can use to reach this key server: the transport's bind address combined with
     * the port the server socket actually bound to, or {@code null} if this member is not currently the key server.
     */
    @Override
    public Address getServerLocation() {
        return this.srv_sock == null ? null : new IpAddress(this.getTransport().getBindAddress(), this.srv_sock.getLocalPort());
    }

    /**
     * Validates the configuration, aligns {@link #secret_key_algorithm} with ASYM_ENCRYPT (if present above),
     * loads the keystore when an SSLContext will have to be built, and instantiates the configured session verifier.
     *
     * @throws IllegalStateException if {@link #port} is 0
     * @throws FileNotFoundException if the keystore can be found neither on the file system nor on the classpath
     */
    @Override
    public void init() throws Exception {
        super.init();
        if (this.port == 0) {
            throw new IllegalStateException("port must not be 0 or else clients will not be able to connect");
        }
        // The key we pass up must use the algorithm ASYM_ENCRYPT expects, so its setting wins
        ASYM_ENCRYPT asym_encrypt = (ASYM_ENCRYPT)this.findProtocolAbove(ASYM_ENCRYPT.class);
        if (asym_encrypt != null) {
            String sym_alg = asym_encrypt.symKeyAlgorithm();
            if (!Util.match(sym_alg, this.secret_key_algorithm)) {
                this.log.warn("%s: overriding %s=%s to %s from %s", this.local_addr, "secret_key_algorithm",
                              this.secret_key_algorithm, sym_alg, ASYM_ENCRYPT.class.getSimpleName());
                this.secret_key_algorithm = sym_alg;
            }
        }
        // A keystore is only needed if at least one of the two contexts has to be built by getContext()
        if (this.key_store == null && (this.client_ssl_ctx == null || this.server_ssl_ctx == null)) {
            this.key_store = this.loadKeystore();
        }
        if (this.session_verifier == null && this.session_verifier_class != null) {
            Class<?> verifier_class = Util.loadClass(this.session_verifier_class, this.getClass());
            this.session_verifier = (SessionVerifier)verifier_class.getDeclaredConstructor().newInstance();
            if (this.session_verifier_arg != null) {
                this.session_verifier.init(this.session_verifier_arg);
            }
        }
    }

    /**
     * Loads {@link #keystore_name} from the file system, falling back to the classpath, and always closes the stream.
     *
     * @return the loaded keystore
     * @throws FileNotFoundException if the keystore cannot be found in either location
     */
    private KeyStore loadKeystore() throws Exception {
        KeyStore ks = KeyStore.getInstance(this.keystore_type != null ? this.keystore_type : KeyStore.getDefaultType());
        InputStream input;
        try {
            input = new FileInputStream(this.keystore_name);
        }
        catch (FileNotFoundException not_found) {
            input = Util.getResourceAsStream(this.keystore_name, this.getClass());
        }
        if (input == null) {
            throw new FileNotFoundException(this.keystore_name);
        }
        try (InputStream in = input) {
            ks.load(in, this.keystore_password.toCharArray());
        }
        return ks;
    }

    @Override
    public void start() throws Exception {
        super.start();
    }

    /** Stops the protocol and shuts down the key server if this member is running one. */
    @Override
    public void stop() {
        super.stop();
        this.stopKeyserver();
    }

    @Override
    public void destroy() {
        super.destroy();
    }

    /**
     * Intercepts event type 56 (presumably the configuration event; NOTE(review): confirm against {@code Event})
     * to pick up {@code bind_addr} from its config map if none was configured, then forwards it up.
     * All other events are handled by the superclass.
     */
    @Override
    public Object up(Event evt) {
        if (evt.getType() == 56) {
            if (this.bind_addr == null) {
                Map<?, ?> config = (Map<?, ?>)evt.getArg();
                this.bind_addr = (InetAddress)config.get("bind_addr");
            }
            return this.up_prot.up(evt);
        }
        return super.up(evt);
    }

    /**
     * Connects to the key server at {@code target}, requests the secret key and passes a
     * {@code Tuple<SecretKeySpec, byte[] version>} up the stack in an event of type 112.
     *
     * @throws IllegalStateException if the server answers with an unexpected type or a malformed length
     * @throws SecurityException     if the configured {@link SessionVerifier} rejects the session
     */
    @Override
    public void fetchSecretKeyFrom(Address target) throws Exception {
        try (SSLSocket sock = this.createSocketTo(target)) {
            DataInputStream in = new DataInputStream(sock.getInputStream());
            OutputStream out = sock.getOutputStream();
            out.write(Type.SECRET_KEY_REQ.ordinal());
            out.flush();
            checkType(in.readByte(), Type.SECRET_KEY_RSP, "response");
            byte[] version = readLengthPrefixed(in, "version");
            byte[] secret_key = readLengthPrefixed(in, "secret key");
            SecretKeySpec sk = new SecretKeySpec(secret_key, this.secret_key_algorithm);
            Tuple<SecretKeySpec, byte[]> tuple = new Tuple<SecretKeySpec, byte[]>(sk, version);
            this.log.debug("%s: sending up secret key (version: %s)", this.local_addr, Util.byteArrayToHexString(version));
            this.up_prot.up(new Event(112, tuple));
        }
    }

    /**
     * Serves a single client connection on the key server: accepts it, completes the handshake, verifies the
     * session, reads the request and replies with the current secret key and its version, which it obtains by
     * sending an event of type 111 up the stack. If nothing comes back from up the stack, the connection is closed
     * without a reply. Any failure is logged at trace level and the connection is closed; the method never throws.
     */
    protected void accept() {
        try (SSLSocket client_sock = (SSLSocket)this.srv_sock.accept()) {
            // Bound the handshake and request read: without a timeout a peer that connects and stays silent
            // would block this (single) accept thread, and thus the key server, indefinitely
            client_sock.setSoTimeout(this.socket_timeout);
            // Use the JSSE default enabled cipher suites: enabling every supported suite may include
            // anonymous/NULL suites that bypass certificate authentication
            client_sock.startHandshake();
            SSLSession sslSession = client_sock.getSession();
            this.log.debug("%s: accepted SSL connection from %s; protocol: %s, cipher suite: %s", this.local_addr, client_sock.getRemoteSocketAddress(), sslSession.getProtocol(), sslSession.getCipherSuite());
            if (this.session_verifier != null) {
                this.session_verifier.verify(sslSession);
            }
            InputStream in = client_sock.getInputStream();
            DataOutputStream out = new DataOutputStream(client_sock.getOutputStream());
            // read() returns -1 on EOF, which checkType() rejects like any other invalid type
            checkType(in.read(), Type.SECRET_KEY_REQ, "request");
            Tuple<?, ?> tuple = (Tuple<?, ?>)this.up_prot.up(new Event(111));
            if (tuple == null) {
                return;
            }
            byte[] version = (byte[])tuple.getVal2();
            byte[] secret_key = ((SecretKey)tuple.getVal1()).getEncoded();
            out.write(Type.SECRET_KEY_RSP.ordinal());
            out.writeInt(version.length);
            out.write(version, 0, version.length);
            out.writeInt(secret_key.length);
            out.write(secret_key);
            out.flush();
        }
        catch (Throwable t) {
            this.log.trace("%s: failure handling client socket: %s", this.local_addr, t.getMessage());
        }
    }

    /**
     * Throws if {@code ordinal} is not the ordinal of {@code expected}. Out-of-range values (including -1 for EOF)
     * are rejected as well, instead of indexing {@code Type.values()} with them.
     *
     * @param kind "request" or "response", used in the error message
     */
    private static void checkType(int ordinal, Type expected, String kind) {
        if (ordinal != expected.ordinal()) {
            throw new IllegalStateException(String.format("expected %s of %s but got type=%d", kind, expected, ordinal));
        }
    }

    /**
     * Reads an int length followed by that many bytes.
     *
     * @throws IllegalStateException if the length is negative
     */
    private static byte[] readLengthPrefixed(DataInputStream in, String what) throws Exception {
        int len = in.readInt();
        if (len < 0) {
            throw new IllegalStateException(String.format("invalid %s length: %d", what, len));
        }
        byte[] buf = new byte[len];
        in.readFully(buf);
        return buf;
    }

    /**
     * Starts the key server when this member becomes coordinator and stops it when it ceases to be coordinator.
     */
    @Override
    protected void handleView(View view) {
        Address old_coord = this.view != null ? this.view.getCoord() : null;
        this.view = view;
        if (Objects.equals(view.getCoord(), this.local_addr)) {
            if (!Objects.equals(old_coord, this.local_addr)) {
                try {
                    this.becomeKeyserver();
                }
                catch (Throwable e) {
                    this.log.error("failed becoming key server", e);
                }
            }
        } else if (Objects.equals(old_coord, this.local_addr)) {
            this.stopKeyserver();
        }
    }

    /** Creates the server socket and starts the accept loop, unless a live server socket already exists. */
    protected synchronized void becomeKeyserver() throws Exception {
        if (this.srv_sock == null || this.srv_sock.isClosed()) {
            this.log.debug("%s: becoming keyserver; creating server socket", this.local_addr);
            this.srv_sock = this.createServerSocket();
            this.srv_sock_handler = new Runner(this.getThreadFactory(), SSL_KEY_EXCHANGE.class.getSimpleName() + "-runner", this::accept, () -> Util.close((Closeable)this.srv_sock));
            this.srv_sock_handler.start();
            this.log.debug("%s: SSL server socket listening on %s", this.local_addr, this.srv_sock.getLocalSocketAddress());
        }
    }

    /** Closes the server socket (unblocking {@code accept()}) and stops the accept loop; a no-op if not running. */
    protected synchronized void stopKeyserver() {
        if (this.srv_sock == null && this.srv_sock_handler == null) {
            return;
        }
        this.log.debug("%s: ceasing to be the keyserver; closing the server socket", this.local_addr);
        if (this.srv_sock != null) {
            Util.close((Closeable)this.srv_sock);
            this.srv_sock = null;
        }
        if (this.srv_sock_handler != null) {
            this.srv_sock_handler.stop();
            this.srv_sock_handler = null;
        }
    }

    /**
     * Binds an SSL server socket to the first free port in {@code [port .. port+port_range]} on {@link #bind_addr}.
     *
     * @throws IllegalStateException if no port in the range could be bound (the last failure is attached as cause)
     */
    protected SSLServerSocket createServerSocket() throws Exception {
        SSLContext ctx = this.server_ssl_ctx != null ? this.server_ssl_ctx : this.getContext();
        SSLServerSocketFactory sslServerSocketFactory = ctx.getServerSocketFactory();
        Throwable last_failure = null;
        for (int i = 0; i <= this.port_range; ++i) {
            try {
                SSLServerSocket sslServerSocket = (SSLServerSocket)sslServerSocketFactory.createServerSocket(this.port + i, 50, this.bind_addr);
                sslServerSocket.setNeedClientAuth(this.require_client_authentication);
                return sslServerSocket;
            }
            catch (Throwable t) {
                last_failure = t; // port taken or invalid: try the next one
            }
        }
        throw new IllegalStateException(String.format("found no valid port to bind to in range [%d-%d]", this.port, this.port + this.port_range), last_failure);
    }

    /**
     * Opens a verified SSL connection to the key server {@code target}. If {@code target} is an {@link IpAddress}
     * it is used as is; otherwise its physical address is looked up (event type 87) and every port in
     * {@code [port .. port+port_range]} is tried in turn.
     *
     * @throws SecurityException     if the session verifier rejects a session (not retried on other ports)
     * @throws IllegalStateException if no physical address is known or no connection could be established
     */
    protected SSLSocket createSocketTo(Address target) throws Exception {
        SSLContext ctx = this.client_ssl_ctx != null ? this.client_ssl_ctx : this.getContext();
        SSLSocketFactory sslSocketFactory = ctx.getSocketFactory();
        if (target instanceof IpAddress) {
            return this.createSocketTo((IpAddress)target, sslSocketFactory);
        }
        IpAddress dest = (IpAddress)this.down_prot.down(new Event(87, target));
        if (dest == null) {
            throw new IllegalStateException(String.format("%s: failed fetching the physical address of %s", this.local_addr, target));
        }
        Throwable last_failure = null;
        for (int i = 0; i <= this.port_range; ++i) {
            try {
                return this.connect(sslSocketFactory, dest.getIpAddress(), this.port + i, target);
            }
            catch (SecurityException sec_ex) {
                throw sec_ex;
            }
            catch (Throwable t) {
                last_failure = t; // nothing (or no key server) on this port: try the next one
            }
        }
        throw new IllegalStateException(String.format("failed connecting to %s (port range [%d - %d])", dest.getIpAddress(), this.port, this.port + this.port_range), last_failure);
    }

    /**
     * Opens a verified SSL connection to exactly {@code dest}.
     *
     * @throws SecurityException     if the session verifier rejects the session
     * @throws IllegalStateException on any other failure (with the original failure as cause)
     */
    protected SSLSocket createSocketTo(IpAddress dest, SSLSocketFactory sslSocketFactory) {
        try {
            return this.connect(sslSocketFactory, dest.getIpAddress(), dest.getPort(), dest);
        }
        catch (SecurityException sec_ex) {
            throw sec_ex;
        }
        catch (Throwable t) {
            throw new IllegalStateException(String.format("failed connecting to %s: %s", dest, t.getMessage()), t);
        }
    }

    /**
     * Creates a socket to {@code addr:remote_port}, applies {@link #socket_timeout}, performs the handshake and
     * runs the session verifier. The socket is closed if any of these steps fails, so it is never leaked.
     *
     * @param target only used for logging
     */
    private SSLSocket connect(SSLSocketFactory factory, InetAddress addr, int remote_port, Object target) throws Exception {
        SSLSocket sock = (SSLSocket)factory.createSocket(addr, remote_port);
        try {
            sock.setSoTimeout(this.socket_timeout);
            // Use the JSSE default enabled cipher suites: enabling every supported suite may include
            // anonymous/NULL suites that bypass certificate authentication
            sock.startHandshake();
            SSLSession sslSession = sock.getSession();
            this.log.debug("%s: created SSL connection to %s (%s); protocol: %s, cipher suite: %s", this.local_addr, target, sock.getRemoteSocketAddress(), sslSession.getProtocol(), sslSession.getCipherSuite());
            if (this.session_verifier != null) {
                this.session_verifier.verify(sslSession);
            }
            return sock;
        }
        catch (Throwable t) {
            Util.close((Closeable)sock);
            throw t;
        }
    }

    /**
     * Returns {@link #client_ssl_ctx}, building and caching it from {@link #key_store} (used for both key and trust
     * material) on first use.
     */
    protected SSLContext getContext() throws Exception {
        if (this.client_ssl_ctx != null) {
            return this.client_ssl_ctx;
        }
        KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        keyManagerFactory.init(this.key_store, this.keystore_password.toCharArray());
        KeyManager[] km = keyManagerFactory.getKeyManagers();
        TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init(this.key_store);
        TrustManager[] tm = trustManagerFactory.getTrustManagers();
        // "TLS" negotiates the highest mutually supported version; "TLSv1" (TLS 1.0) is disabled in current JDKs
        SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(km, tm, null);
        this.client_ssl_ctx = sslContext;
        return this.client_ssl_ctx;
    }

    /** Hook for additional, application-specific checks on established SSL sessions (client and server side). */
    public interface SessionVerifier {
        /** Called once after instantiation with {@code session_verifier_arg}, if that property is set. */
        void init(String arg);

        /** Rejects {@code session} by throwing a {@link SecurityException}. */
        void verify(SSLSession session) throws SecurityException;
    }

    /** Message types; their ordinals are written on the wire, so the order must not change. */
    protected static enum Type {
        SECRET_KEY_REQ,
        SECRET_KEY_RSP;

    }
}

