package com.atlassian.crowd.directory.ssl;

import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.net.UnknownHostException;
import java.security.NoSuchAlgorithmException;
import java.util.Comparator;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * <p>A {@link SocketFactory} that delegates to the JVM's default {@link SSLContext} socket factory and,
 * on every socket it hands out, calls
 * {@link SSLParameters#setEndpointIdentificationAlgorithm(String)} with {@code "LDAPS"} so that the
 * server's hostname is verified against its certificate during the TLS handshake.</p>
 *
 * <p>It also implements {@code Comparator<String>} so that JNDI's LDAP connection pooling can compare
 * socket factories; see {@link #compare(String, String)}.</p>
 */
public class LdapHostnameVerificationSSLSocketFactory extends SocketFactory implements Comparator<String> {
    private static final Logger log = LoggerFactory.getLogger(LdapHostnameVerificationSSLSocketFactory.class);

    /** Delegate factory obtained from {@link SSLContext#getDefault()}. */
    private final SSLSocketFactory sf;

    private LdapHostnameVerificationSSLSocketFactory() throws NoSuchAlgorithmException {
        this.sf = SSLContext.getDefault().getSocketFactory();
    }

    /**
     * Creates a new factory instance backed by the default {@link SSLContext}.
     *
     * @return a new {@code LdapHostnameVerificationSSLSocketFactory}
     * @throws IllegalStateException if the default {@link SSLContext} cannot be obtained
     */
    public static synchronized SocketFactory getDefault() {
        log.debug("Name checking SSLSocketFactory created");
        try {
            return new LdapHostnameVerificationSSLSocketFactory();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("Default SSLContext is unavailable", e);
        }
    }

    /**
     * Enables LDAPS hostname verification on the given socket by calling
     * {@link SSLParameters#setEndpointIdentificationAlgorithm(String)} with {@code "LDAPS"}.
     *
     * @param s the socket to configure; must be a non-null {@link SSLSocket}
     * @throws IllegalArgumentException if {@code s} is {@code null} or not an {@link SSLSocket}
     */
    static void makeUseLdapVerification(Socket s) {
        if (s == null) {
            throw new IllegalArgumentException("Unexpected SSLSocket implementation: null");
        }
        if (!(s instanceof SSLSocket)) {
            throw new IllegalArgumentException("Unexpected SSLSocket implementation: " + s.getClass().getName());
        }

        SSLSocket ssls = (SSLSocket) s;

        // getSSLParameters() returns a copy, so the modified copy must be written back.
        SSLParameters param = ssls.getSSLParameters();
        param.setEndpointIdentificationAlgorithm("LDAPS");
        ssls.setSSLParameters(param);
    }

    /**
     * Applies {@link #makeUseLdapVerification(Socket)} to a freshly created socket. If configuration
     * fails, the socket is closed before the failure propagates so it is not leaked.
     *
     * @param s a socket just created by the delegate factory
     * @return {@code s}, configured for LDAPS hostname verification
     */
    private static Socket withLdapVerification(Socket s) {
        try {
            makeUseLdapVerification(s);
            return s;
        } catch (RuntimeException e) {
            if (s != null) {
                try {
                    s.close();
                } catch (IOException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw e;
        }
    }

    @Override
    public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort)
            throws IOException {
        // Was logged at WARN, inconsistent with every other overload; this is routine tracing.
        log.debug("Creating socket to {}", address);
        return withLdapVerification(sf.createSocket(address, port, localAddress, localPort));
    }

    @Override
    public Socket createSocket(InetAddress host, int port) throws IOException {
        log.debug("Creating socket to {}", host);
        return withLdapVerification(sf.createSocket(host, port));
    }

    @Override
    public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
        log.debug("Creating socket to {}", host);
        return withLdapVerification(sf.createSocket(host, port));
    }

    @Override
    public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException,
            UnknownHostException {
        log.debug("Creating socket to {}", host);
        return withLdapVerification(sf.createSocket(host, port, localHost, localPort));
    }

    @Override
    public Socket createSocket() throws IOException {
        log.debug("Creating disconnected socket");
        return withLdapVerification(sf.createSocket());
    }

    /**
     * As per <a href="http://docs.oracle.com/javase/6/docs/technotes/guides/jndi/jndi-ldap.html#pooling">Pooling Custom Socket Factory Connections</a>,
     * supports comparing socket factories.
     * We compare <code>String</code> instances (the factory class name) rather than <code>SocketFactory</code>
     * instances because of a bug in Java; see
     * <a href="http://stackoverflow.com/questions/23898970/pooling-ldap-connections-with-custom-socket-factory">Pooling LDAP connections with custom socket factory - Stack Overflow</a>.
     *
     * @param socketFactory1 first socket factory name
     * @param socketFactory2 second socket factory name
     * @return the lexicographic comparison of the two names, as {@link String#compareTo(String)}
     */
    @Override
    public int compare(String socketFactory1, String socketFactory2) {
        return socketFactory1.compareTo(socketFactory2);
    }
}
