package com.atlassian.crowd.directory.ldap;

import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import javax.naming.Name;
import javax.naming.NamingEnumeration;
import javax.naming.directory.Attributes;
import javax.naming.directory.DirContext;
import javax.naming.directory.ModificationItem;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapName;

import com.atlassian.crowd.directory.LimitedNamingEnumeration;
import com.atlassian.crowd.directory.ldap.mapper.AttributeToContextCallbackHandler;
import com.atlassian.crowd.directory.ldap.mapper.ContextMapperWithRequiredAttributes;
import com.atlassian.crowd.directory.ldap.mapper.LookupCallbackHandler;
import com.atlassian.crowd.search.query.entity.EntityQuery;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.collect.Iterables;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ldap.core.AttributesMapper;
import org.springframework.ldap.core.ContextMapper;
import org.springframework.ldap.core.ContextMapperCallbackHandler;
import org.springframework.ldap.core.DirContextProcessor;
import org.springframework.ldap.core.LdapTemplate;
import org.springframework.ldap.core.NameClassPairCallbackHandler;
import org.springframework.ldap.core.SearchExecutor;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;

/**
 * <p>
 * Wrap an {@link LdapTemplate} and perform all operations with the context
 * ClassLoader set to this class's ClassLoader.
 * <code>com.sun.naming.internal.NamingManager</code> uses the context
 * ClassLoader so, without this wrapper, calls that originate from plugins and
 * end up using LDAP will fail when they can't see the Spring LDAP
 * implementation classes.
 * </p>
 * <p>
 * Also logs how long each LDAP operation took: at {@code DEBUG} level for all operations, or at
 * {@code INFO} level if the duration exceeds a threshold in milliseconds (default 1000ms), configurable
 * via the system property {@code com.atlassian.crowd.ldap.log.wait.threshold}.
 * </p>
 * <p>
 * This class is the blessed way to interact with LDAP. LdapTemplate should <strong>not</strong> be used directly,
 * as this could expose us to an LDAP object injection vulnerability (see CWD-4754). This class calls very specific
 * methods of LdapTemplate in order to avoid manipulation of the {@link SearchControls}: other LdapTemplate methods can
 * set the <code>returningObj</code> flag in the SearchControls to true before executing the search, exposing us to
 * that very vulnerability. Care should be taken when upgrading Spring LDAP to ensure that the search methods called
 * still behave the same.
 * </p>
 * <p>
 * As a safety net, the search methods of this class throw an {@link IllegalArgumentException} if the
 * {@link SearchControls} provided have the <code>returningObj</code> flag set to true.
 * </p>
 */
public class SpringLdapTemplateWrapper {
    /** Name of the system property holding the slow-operation logging threshold, in milliseconds. */
    private static final String TIMED_LOG_THRESHOLD_PROPERTY = "com.atlassian.crowd.ldap.log.wait.threshold";
    /** Threshold used when the system property is absent or cannot be parsed. */
    private static final long DEFAULT_TIMED_LOG_THRESHOLD_MILLIS = 1000L;
    private static final Logger logger = LoggerFactory.getLogger(SpringLdapTemplateWrapper.class);
    private final LdapTemplate template;
    private final long logThreshold;

    /**
     * Creates a wrapper delegating to {@code template}. The logging threshold is read from the
     * system property once, here, and reused for every operation.
     *
     * @param template the Spring LDAP template to delegate to
     */
    public SpringLdapTemplateWrapper(LdapTemplate template) {
        this.template = template;
        this.logThreshold = getLogThreshold();
    }

    /**
     * Reads the slow-operation logging threshold from the system property, falling back to
     * {@link #DEFAULT_TIMED_LOG_THRESHOLD_MILLIS} when the property is missing or is not a valid {@code long}.
     */
    private static long getLogThreshold() {
        final String thresholdString = System.getProperty(TIMED_LOG_THRESHOLD_PROPERTY);
        if (thresholdString == null) {
            return DEFAULT_TIMED_LOG_THRESHOLD_MILLIS;
        }
        try {
            // Tolerate surrounding whitespace in the property value.
            return Long.parseLong(thresholdString.trim());
        } catch (NumberFormatException e) {
            // Derive the message from the constants so it cannot drift from the actual behaviour.
            logger.warn("Could not parse '{}' value '{}'. Using default of {}ms.",
                    TIMED_LOG_THRESHOLD_PROPERTY, thresholdString, DEFAULT_TIMED_LOG_THRESHOLD_MILLIS);
            return DEFAULT_TIMED_LOG_THRESHOLD_MILLIS;
        }
    }

    /**
     * Guards against LDAP object injection (CWD-4754) by rejecting search controls whose
     * {@code returningObj} flag is set.
     *
     * @throws IllegalArgumentException if {@code controls.getReturningObjFlag()} is true
     */
    private static void checkReturningObjFlagUnset(final SearchControls controls) {
        Preconditions.checkArgument(!controls.getReturningObjFlag(),
                "SearchControls must not have the returningObj flag set");
    }

    /**
     * Runs {@code runnable} with the thread context ClassLoader temporarily set to the ClassLoader
     * that loaded this class, restoring the previous context ClassLoader afterwards (also on exception).
     *
     * @param runnable the work to perform
     * @return whatever {@code runnable} returns
     */
    static <T> T invokeWithContextClassLoader(CallableWithoutCheckedException<T> runnable) {
        Thread current = Thread.currentThread();
        ClassLoader orig = current.getContextClassLoader();

        try {
            ClassLoader classLoaderForThisClass = SpringLdapTemplateWrapper.class.getClassLoader();
            current.setContextClassLoader(classLoaderForThisClass);

            return runnable.call();
        } finally {
            current.setContextClassLoader(orig);
        }
    }

    /**
     * A callable that measures how long {@link #timedCall()} takes. It logs the duration at INFO when
     * it exceeds the threshold, and otherwise at DEBUG (only when DEBUG is enabled). The
     * {@link #message()} describing the call is only built when something is actually logged.
     */
    @VisibleForTesting
    static abstract class TimedCallable<T> implements CallableWithoutCheckedException<T> {
        private final Logger log;
        private final Stopwatch watch;
        private final long thresholdMillis;

        public TimedCallable(final long timedThreshold) {
            this(Stopwatch.createUnstarted(), logger, timedThreshold);
        }

        public TimedCallable(Stopwatch watch, Logger log, long thresholdMillis) {
            this.watch = watch;
            this.log = log;
            this.thresholdMillis = thresholdMillis;
        }

        /** The operation to time. */
        public abstract T timedCall();

        /** A short description of the operation, used in the timing log line. */
        public abstract String message();

        @Override
        public final T call() {
            watch.start();
            try {
                return timedCall();
            } finally {
                watch.stop();
                final long elapsedMillis = watch.elapsed(TimeUnit.MILLISECONDS);
                if (elapsedMillis > thresholdMillis) {
                    log.info("Timed call for {} took {}ms", message(), elapsedMillis);
                } else if (log.isDebugEnabled()) {
                    log.debug("Timed call for {} took {}ms", message(), elapsedMillis);
                }
            }
        }
    }

    /**
     * Searches from {@code base} with {@code filter} and {@code controls}, mapping each result with {@code mapper}.
     *
     * @return the list of mapped results
     * @throws IllegalArgumentException if {@code controls} has the returningObj flag set
     */
    public List search(final Name base, final String filter, final SearchControls controls, final ContextMapper mapper) {
        checkReturningObjFlagUnset(controls);
        return invokeWithContextClassLoader(new TimedCallable<List>(logThreshold) {
            @Override
            public List timedCall() {
                SearchExecutor se = (ctx) -> ctx.search(base, filter, controls);
                final AttributeToContextCallbackHandler handler = new AttributeToContextCallbackHandler<>(mapper);
                template.search(se, handler, new LdapTemplate.NullDirContextProcessor());
                return handler.getList();
            }

            @Override
            public String message() {
                return "search on " + base;
            }
        });
    }

    /**
     * Same as {@link #search(Name, String, SearchControls, ContextMapper)} but lets the caller supply a
     * {@link DirContextProcessor}.
     *
     * @return the list of mapped results
     * @throws IllegalArgumentException if {@code controls} has the returningObj flag set
     */
    public List search(final Name base, final String filter, final SearchControls controls, final ContextMapper mapper,
                       final DirContextProcessor processor) {
        checkReturningObjFlagUnset(controls);
        final AttributeToContextCallbackHandler handler = new AttributeToContextCallbackHandler<>(mapper);
        search(base, filter, controls, handler, processor);
        return handler.getList();
    }

    /**
     * Looks up the entry at {@code dn} with an object-scope {@code (objectClass=*)} search. The search
     * requests all attributes and always has the returningObj flag cleared.
     *
     * @return the first result collected by a {@link LookupCallbackHandler}, or {@code null} if there is none
     */
    public Object lookup(final Name dn) {
        return invokeWithContextClassLoader(new TimedCallable<Object>(logThreshold) {
            @Override
            public Object timedCall() {
                SearchExecutor se = (ctx) -> {
                    final SearchControls searchControls = new SearchControls();
                    searchControls.setSearchScope(SearchControls.OBJECT_SCOPE);
                    searchControls.setReturningAttributes(null);
                    searchControls.setReturningObjFlag(false);
                    return ctx.search(dn, "(objectClass=*)", searchControls);
                };
                final LookupCallbackHandler<Object> handler = new LookupCallbackHandler<>();
                template.search(se, handler);
                return Iterables.getFirst(handler.getList(), null);
            }

            @Override
            public String message() {
                return "lookup on " + dn;
            }
        });
    }

    /**
     * Searches from {@code base} with {@code filter} and {@code controls}, feeding results to the
     * caller-supplied {@code handler} and {@code processor}.
     *
     * @throws IllegalArgumentException if {@code controls} has the returningObj flag set
     */
    public void search(final Name base, final String filter, final SearchControls controls,
                       final AttributeToContextCallbackHandler handler, final DirContextProcessor processor) {
        checkReturningObjFlagUnset(controls);
        invokeWithContextClassLoader(new TimedCallable<Object>(logThreshold) {
            @Override
            public Void timedCall() {
                SearchExecutor se = (ctx) -> ctx.search(base, filter, controls);
                template.search(se, handler, processor);
                return null;
            }

            @Override
            public String message() {
                return "search with handler on " + base;
            }
        });
    }

    /** Delegates to {@link LdapTemplate#unbind(Name)} under the wrapper's ClassLoader and timing. */
    public void unbind(final Name dn) {
        invokeWithContextClassLoader(new TimedCallable<Object>(logThreshold) {
            @Override
            public Void timedCall() {
                template.unbind(dn);
                return null;
            }

            @Override
            public String message() {
                return "unbind on " + dn;
            }
        });
    }

    /** Delegates to {@link LdapTemplate#bind(Name, Object, Attributes)} under the wrapper's ClassLoader and timing. */
    public void bind(final Name dn, final Object obj, final Attributes attributes) {
        invokeWithContextClassLoader(new TimedCallable<Object>(logThreshold) {
            @Override
            public Void timedCall() {
                template.bind(dn, obj, attributes);
                return null;
            }

            @Override
            public String message() {
                return "bind on " + dn;
            }
        });
    }

    /**
     * Delegates to {@link LdapTemplate#modifyAttributes(Name, ModificationItem[])} under the wrapper's
     * ClassLoader and timing.
     */
    public void modifyAttributes(final Name dn, final ModificationItem[] mods) {
        invokeWithContextClassLoader(new TimedCallable<Object>(logThreshold) {
            @Override
            public Void timedCall() {
                template.modifyAttributes(dn, mods);
                return null;
            }

            @Override
            public String message() {
                return "modify attributes on " + dn;
            }
        });
    }

    /**
     * Looks up {@code attributes} of the entry at {@code dn} and passes them to {@code mapper}. The value
     * returned by {@code LdapTemplate#lookup} is discarded, so this is only useful with a mapper that
     * records its results itself.
     */
    public void lookup(final LdapName dn, final String[] attributes, final AttributesMapper mapper) {
        invokeWithContextClassLoader(new TimedCallable<Object>(logThreshold) {
            @Override
            public Void timedCall() {
                template.lookup(dn, attributes, mapper);
                return null;
            }

            @Override
            public String message() {
                return "lookup on " + dn;
            }
        });
    }

    /**
     * Looks up the entry at {@code dn}, requesting only the attributes {@code mapper} declares as required.
     *
     * @return the value produced by {@code mapper}
     */
    public <T> T lookup(final LdapName dn, final ContextMapperWithRequiredAttributes<T> mapper) {
        Set<String> attrSet = mapper.getRequiredLdapAttributes();

        final String[] attributes = attrSet.toArray(new String[attrSet.size()]);

        return invokeWithContextClassLoader(new TimedCallable<T>(logThreshold) {
            @Override
            @SuppressWarnings("unchecked")
            public T timedCall() {
                return (T) template.lookup(dn, attributes, mapper);
            }

            @Override
            public String message() {
                return "lookup with mapper on " + dn;
            }
        });
    }

    /** Delegates directly to the template; no ClassLoader switch is needed because this loads no classes. */
    public void setIgnorePartialResultException(boolean ignore) {
        template.setIgnorePartialResultException(ignore);
    }

    /**
     * Runs a search built by {@code se} under the wrapper's ClassLoader and timing. Callers are responsible
     * for having validated the SearchControls that {@code se} uses.
     */
    private void search(final SearchExecutor se, final NameClassPairCallbackHandler handler, final DirContextProcessor processor) {
        invokeWithContextClassLoader(new TimedCallable<Object>(logThreshold) {
            @Override
            public Void timedCall() {
                template.search(se, handler, processor);
                return null;
            }

            @Override
            public String message() {
                return "search using searchexecutor " + se;
            }
        });
    }

    /** A {@link java.util.concurrent.Callable} variant whose {@code call} declares no checked exceptions. */
    @FunctionalInterface
    interface CallableWithoutCheckedException<T> {
        T call();
    }

    /**
     * Searches from {@code baseDN} with {@code filter} and {@code searchControls}, mapping each result with
     * {@code contextMapper}. Results are truncated to {@code limit} entries unless {@code limit} is
     * {@link EntityQuery#ALL_RESULTS}.
     *
     * @return the list of mapped results
     * @throws IllegalArgumentException if {@code searchControls} has the returningObj flag set
     */
    public List searchWithLimitedResults(final Name baseDN, final String filter, final SearchControls searchControls,
                                         ContextMapper contextMapper, DirContextProcessor processor, final int limit) {
        checkReturningObjFlagUnset(searchControls);
        SearchExecutor se = new SearchExecutor() {
            @Override
            @SuppressFBWarnings(value = "LDAP_INJECTION", justification = "No user input, filter is encoded in all calls")
            public NamingEnumeration<SearchResult> executeSearch(DirContext ctx) throws javax.naming.NamingException {
                NamingEnumeration<SearchResult> ne = ctx.search(baseDN, filter, searchControls);

                if (limit != EntityQuery.ALL_RESULTS) {
                    return new LimitedNamingEnumeration<SearchResult>(ne, limit);
                } else {
                    return ne;
                }
            }
        };

        ContextMapperCallbackHandler handler = new AttributeToContextCallbackHandler(contextMapper);

        search(se, handler, processor);

        return handler.getList();
    }
}
