package com.atlassian.crowd.directory;

import com.atlassian.crowd.directory.ldap.LDAPPropertiesMapper;
import com.atlassian.crowd.directory.ldap.LdapSecureMode;
import com.atlassian.crowd.directory.ssl.CrowdTlsDirContextAuthenticationStrategy;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessResourceFailureException;
import org.springframework.ldap.core.ContextSource;
import org.springframework.ldap.core.support.LdapContextSource;
import org.springframework.ldap.pool2.DirContextType;
import org.springframework.ldap.pool2.factory.PoolConfig;
import org.springframework.ldap.pool2.factory.PooledContextSource;
import org.springframework.ldap.pool2.validation.DefaultDirContextValidator;

import javax.naming.Context;
import javax.naming.directory.DirContext;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.BLOCK_WHEN_EXHAUSTED;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.EVICTION_POLICY_CLASS;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.EVICTION_RUN_INTERVAL_MILLIS;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.FAIRNESS;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.JMX_ENABLE;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.JMX_NAME_PREFIX;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.LIFO;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.MAX_IDLE_PER_KEY;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.MAX_TOTAL;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.MAX_TOTAL_PER_KEY;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.MAX_WAIT;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.MIN_EVICTABLE_TIME_MILLIS;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.MIN_IDLE_PER_KEY;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.SOFT_MIN_EVICTABLE_TIME_MILLIS;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.TESTS_PER_EVICTION_RUN;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.TEST_ON_BORROW;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.TEST_ON_CREATE;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.TEST_ON_RETURN;
import static com.atlassian.crowd.common.properties.SpringLdapPoolProperties.TEST_WHILE_IDLE;

/**
 * Builds the Spring LDAP {@link ContextSource} variants used to connect to an LDAP directory:
 * <ul>
 *     <li>a minimal, non-pooled source bound with caller-supplied credentials;</li>
 *     <li>a source bound with the directory's configured credentials, optionally using the
 *     context source's built-in ("legacy") pooling;</li>
 *     <li>a source wrapped in a Spring LDAP {@code pool2} {@link PooledContextSource}.</li>
 * </ul>
 * StartTLS is applied to every variant when the directory's secure mode is {@link LdapSecureMode#START_TLS}.
 */
public class LdapContextSourceFactory {
    private static final Logger logger = LoggerFactory.getLogger(LdapContextSourceFactory.class);

    /** Provides the fresh, unconfigured {@link LdapContextSource} that each factory method starts from. */
    private final Supplier<LdapContextSource> ldapContextSourceSupplier;

    /** Creates a factory that instantiates a new {@link LdapContextSource} for every context source it builds. */
    public LdapContextSourceFactory() {
        this(LdapContextSource::new);
    }

    /**
     * Creates a factory that takes its {@link LdapContextSource} instances from the given supplier,
     * so tests can substitute their own instances.
     *
     * @param ldapContextSourceSupplier supplies the instance configured by each factory call
     */
    @VisibleForTesting
    LdapContextSourceFactory(final Supplier<LdapContextSource> ldapContextSourceSupplier) {
        this.ldapContextSourceSupplier = ldapContextSourceSupplier;
    }

    /**
     * Creates a non-pooled context source that binds as {@code username}/{@code password}
     * instead of the directory's configured credentials.
     *
     * <p>Unlike {@link #createContextSource}, this method has two differences:
     * <ul>
     *     <li>it never installs a custom initial context factory class;</li>
     *     <li>it does not catch failures raised by {@code afterPropertiesSet()}.</li>
     * </ul>
     *
     * @param username             DN to bind as
     * @param password             password for {@code username}
     * @param ldapPropertiesMapper supplies the connection URL and the secure mode
     * @param envProperties        base JNDI environment for the source
     * @return the fully initialised context source
     */
    ContextSource createMinimalContextSource(final String username, final String password, final LDAPPropertiesMapper ldapPropertiesMapper, final Map<String, Object> envProperties) {
        final LdapContextSource minimalSource = ldapContextSourceSupplier.get();
        minimalSource.setUrl(ldapPropertiesMapper.getConnectionURL());
        minimalSource.setUserDn(username);
        minimalSource.setPassword(password);
        minimalSource.setBaseEnvironmentProperties(envProperties);
        minimalSource.setPooled(false);
        maybeApplyTls(ldapPropertiesMapper, minimalSource);
        minimalSource.afterPropertiesSet();
        return minimalSource;
    }

    /**
     * Creates a context source bound with the credentials held by {@code ldapPropertiesMapper}.
     *
     * <p>If {@code envProperties} contains {@link Context#INITIAL_CONTEXT_FACTORY}, that class is loaded
     * and installed as the context factory. The source's own pooling is switched on only when
     * {@code useLegacyPooling} is set and StartTLS was not applied.
     *
     * <p>A failure in {@code afterPropertiesSet()} is logged rather than thrown. The returned source
     * may therefore be left only partially configured.
     *
     * @param ldapPropertiesMapper supplies URL, credentials and secure mode
     * @param envProperties        base JNDI environment; its initial-context-factory entry, if any, must be a class name
     * @param useLegacyPooling     whether to request the context source's built-in pooling
     * @return the configured context source
     * @throws NoClassDefFoundError if the named initial context factory class cannot be found
     */
    ContextSource createContextSource(final LDAPPropertiesMapper ldapPropertiesMapper, final Map<String, Object> envProperties, boolean useLegacyPooling) {
        final LdapContextSource source = ldapContextSourceSupplier.get();

        final String factoryClassName = (String) envProperties.get(Context.INITIAL_CONTEXT_FACTORY);
        if (factoryClassName != null) {
            source.setContextFactory(loadContextFactoryClass(factoryClassName));
        }

        source.setUrl(ldapPropertiesMapper.getConnectionURL());
        source.setUserDn(ldapPropertiesMapper.getUsername());
        source.setPassword(ldapPropertiesMapper.getPassword());
        source.setBaseEnvironmentProperties(envProperties);

        final boolean startTlsInUse = maybeApplyTls(ldapPropertiesMapper, source);
        // Built-in pooling is never requested together with StartTLS
        // (presumably the two are incompatible; confirm before relaxing).
        source.setPooled(!startTlsInUse && useLegacyPooling);

        try {
            // Lets the context source validate and assemble its configuration for the LDAP server.
            source.afterPropertiesSet();
        } catch (Exception e) {
            logger.error("Failed to configure context source", e);
        }
        return source;
    }

    /**
     * Loads (without initialising) the named initial context factory class using the class loader
     * of {@code SpringLDAPConnector}.
     *
     * @throws NoClassDefFoundError carrying the original {@link ClassNotFoundException} as its cause
     */
    private static Class<?> loadContextFactoryClass(final String className) {
        try {
            return Class.forName(className, false, SpringLDAPConnector.class.getClassLoader());
        } catch (ClassNotFoundException notFound) {
            final NoClassDefFoundError error = new NoClassDefFoundError(className);
            error.initCause(notFound);
            throw error;
        }
    }

    /**
     * Reports whether two mappers agree on every setting that matters when building a {@link ContextSource}.
     * Those settings are connection URL, username, password and secure mode, each compared with
     * {@link Objects#equals}.
     */
    boolean areConnectionPropertiesSame(LDAPPropertiesMapper left, LDAPPropertiesMapper right) {
        final Stream<Function<LDAPPropertiesMapper, Object>> connectionSettings = Stream.of(
                LDAPPropertiesMapper::getConnectionURL,
                LDAPPropertiesMapper::getUsername,
                LDAPPropertiesMapper::getPassword,
                LDAPPropertiesMapper::getSecureMode);
        return connectionSettings.noneMatch(setting -> !Objects.equals(setting.apply(left), setting.apply(right)));
    }

    /**
     * Installs Crowd's TLS authentication strategy when the mapper's secure mode is StartTLS.
     *
     * @return {@code true} if the strategy was installed, {@code false} if the source was left untouched
     */
    private boolean maybeApplyTls(LDAPPropertiesMapper ldapPropertiesMapper, LdapContextSource ldapContextSource) {
        final boolean startTls = ldapPropertiesMapper.getSecureMode() == LdapSecureMode.START_TLS;
        if (startTls) {
            ldapContextSource.setAuthenticationStrategy(new CrowdTlsDirContextAuthenticationStrategy());
        }
        return startTls;
    }

    /**
     * Creates a pooled context source for a directory.
     *
     * <p>The target is a context source from {@link #createContextSource} with built-in pooling
     * disabled. It is wrapped in a {@link CrowdPooledContextSource} configured from the
     * {@code SpringLdapPoolProperties} values. Connections are validated with a {@link DefaultDirContextValidator}.
     *
     * @param directoryId          identifies the directory in the pool's JMX name prefix
     * @param ldapPropertiesMapper supplies URL, credentials and secure mode
     * @param envProperties        base JNDI environment for the underlying source
     * @return the pooling context source
     */
    PooledContextSource createPooledContextSource(long directoryId, LDAPPropertiesMapper ldapPropertiesMapper, Map<String, Object> envProperties) {
        final ContextSource targetSource = createContextSource(ldapPropertiesMapper, envProperties, false);
        final PooledContextSource pooledSource = new CrowdPooledContextSource(buildPoolConfig(directoryId));
        pooledSource.setContextSource(targetSource);
        pooledSource.setDirContextValidator(new DefaultDirContextValidator());
        return pooledSource;
    }

    /** Builds a {@link PoolConfig} from the {@code SpringLdapPoolProperties} values for the given directory. */
    private static PoolConfig buildPoolConfig(final long directoryId) {
        final PoolConfig config = new PoolConfig();
        config.setBlockWhenExhausted(BLOCK_WHEN_EXHAUSTED.getValue());
        config.setEvictionPolicyClassName(EVICTION_POLICY_CLASS.getValue());
        config.setFairness(FAIRNESS.getValue());
        config.setJmxEnabled(JMX_ENABLE.getValue());
        // NOTE(review): this prefix ends with two trailing spaces. It is unclear whether that is intentional;
        // confirm before changing, since altering it changes the JMX names the pool registers under.
        config.setJmxNamePrefix(String.format("%s-directory-%d  ", JMX_NAME_PREFIX.getValue(), directoryId));
        config.setLifo(LIFO.getValue());
        config.setMaxIdlePerKey(MAX_IDLE_PER_KEY.getValue());
        config.setMaxTotal(MAX_TOTAL.getValue());
        config.setMaxTotalPerKey(MAX_TOTAL_PER_KEY.getValue());
        config.setMaxWaitMillis(MAX_WAIT.getValue());
        config.setMinEvictableIdleTimeMillis(MIN_EVICTABLE_TIME_MILLIS.getValue());
        config.setSoftMinEvictableIdleTimeMillis(SOFT_MIN_EVICTABLE_TIME_MILLIS.getValue());
        config.setNumTestsPerEvictionRun(TESTS_PER_EVICTION_RUN.getValue());
        config.setTimeBetweenEvictionRunsMillis(EVICTION_RUN_INTERVAL_MILLIS.getValue());
        config.setTestOnBorrow(TEST_ON_BORROW.getValue());
        config.setTestOnCreate(TEST_ON_CREATE.getValue());
        config.setTestOnReturn(TEST_ON_RETURN.getValue());
        config.setTestWhileIdle(TEST_WHILE_IDLE.getValue());
        config.setMinIdlePerKey(MIN_IDLE_PER_KEY.getValue());
        return config;
    }

    /**
     * {@link PooledContextSource} that rethrows the underlying runtime exception when the pool wraps
     * a failure in a {@link DataAccessResourceFailureException}.
     */
    @VisibleForTesting
    static class CrowdPooledContextSource extends PooledContextSource {
        CrowdPooledContextSource(PoolConfig poolConfig) {
            super(poolConfig);
        }

        /**
         * Borrows a context from the pool. A {@link DataAccessResourceFailureException} whose cause is a
         * {@link RuntimeException} is replaced by that cause; any other such exception is rethrown as-is.
         */
        @Override
        protected DirContext getContext(DirContextType dirContextType) {
            try {
                return super.getContext(dirContextType);
            } catch (DataAccessResourceFailureException wrapped) {
                // Much of the calling code depends on the exceptions raised by the context sources
                // themselves, so the pool's wrapper is peeled off.
                final Throwable cause = wrapped.getCause();
                throw cause instanceof RuntimeException ? (RuntimeException) cause : wrapped;
            }
        }
    }
}
