package com.atlassian.crowd.directory.ldap;

import com.atlassian.crowd.common.properties.DurationSystemProperty;
import com.atlassian.crowd.directory.LimitedNamingEnumeration;
import com.atlassian.crowd.directory.ldap.mapper.AttributeToContextCallbackHandler;
import com.atlassian.crowd.directory.ldap.mapper.ContextMapperWithRequiredAttributes;
import com.atlassian.crowd.directory.ldap.mapper.LookupCallbackHandler;
import com.atlassian.crowd.directory.ldap.monitoring.ExecutionInfoNameClassPairCallbackHandler;
import com.atlassian.crowd.directory.ldap.monitoring.TimedSupplier;
import com.atlassian.crowd.search.query.entity.EntityQuery;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ldap.core.AttributesMapper;
import org.springframework.ldap.core.ContextMapper;
import org.springframework.ldap.core.ContextMapperCallbackHandler;
import org.springframework.ldap.core.DirContextProcessor;
import org.springframework.ldap.core.LdapTemplate;
import org.springframework.ldap.core.NameClassPairCallbackHandler;
import org.springframework.ldap.core.SearchExecutor;

import javax.naming.Name;
import javax.naming.NamingEnumeration;
import javax.naming.directory.Attributes;
import javax.naming.directory.DirContext;
import javax.naming.directory.ModificationItem;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapName;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;

/**
 * <p>
 * Wrap an {@link LdapTemplate} and perform all operations with the context
 * ClassLoader set to this class's ClassLoader.
 * <code>com.sun.naming.internal.NamingManager</code> uses the context
 * ClassLoader so, without this wrapper, calls that originate from plugins and
 * end up using LDAP will fail when they can't see the Spring LDAP
 * implementation classes.
 * </p>
 * <p>
 * Also logs how long the ldap query took, at {@code DEBUG} level for all queries, or at
 * {@code INFO} level if the query's duration exceeds a threshold (default 1 second),
 * settable by {@code com.atlassian.crowd.ldap.log.wait.threshold}.
 * </p>
 * <p>
 * This class is the blessed way to interact with LDAP. LdapTemplate should <strong>not</strong> be used directly,
 * as this could open us up to an LDAP object injection vulnerability (see CWD-4754). This class calls very specific
 * methods of LdapTemplate in order to avoid manipulation of the {@link SearchControls}. LdapTemplate can set the
 * <code>returnObj</code> flag in the SearchControls to true before executing the search, opening us up to that very
 * vulnerability. Care should be taken when upgrading Spring LDAP to ensure that the search methods called still behave
 * the same.
 * </p>
 * <p>
 * As a safety net around providing SearchControls with the <code>returningObj</code> flag set to false, the search methods
 * of this class will throw an {@link IllegalArgumentException} if the search controls provided had the said flag set to
 * true.
 * </p>
 */
public class SpringLdapTemplateWrapper {
    private static final DurationSystemProperty TIMED_LOG_THRESHOLD = new DurationSystemProperty(
            "com.atlassian.crowd.ldap.log.wait.threshold", ChronoUnit.MILLIS, 1000L);
    private static final Logger logger = LoggerFactory.getLogger(SpringLdapTemplateWrapper.class);
    private final LdapTemplate template;
    private final long logThreshold;

    public SpringLdapTemplateWrapper(LdapTemplate template) {
        this.template = template;
        this.logThreshold = TIMED_LOG_THRESHOLD.getValue().toMillis();
    }

    static <T> T invokeWithContextClassLoader(Supplier<T> supplier) {
        Thread current = Thread.currentThread();
        ClassLoader orig = current.getContextClassLoader();

        try {
            ClassLoader classLoaderForThisClass = SpringLdapTemplateWrapper.class.getClassLoader();
            current.setContextClassLoader(classLoaderForThisClass);

            return supplier.get();
        } finally {
            current.setContextClassLoader(orig);
        }
    }

    public List search(final Name base, final String filter, final SearchControls controls, final ContextMapper mapper) {
        Preconditions.checkArgument(!controls.getReturningObjFlag());
        final String operationDescription = "search on " + base + " with filter: " + filter;
        return invokeWithContextClassLoader(new TimedSupplier<List>(operationDescription, logThreshold) {
            public List timedGet() {
                final SearchExecutor se = (ctx) -> ctx.search(base, filter, controls);
                final ContextMapperCallbackHandler handler = new AttributeToContextCallbackHandler<>(mapper);
                final ExecutionInfoNameClassPairCallbackHandler wrappedHandler = wrapHandler(handler);

                template.search(se, wrappedHandler, new LdapTemplate.NullDirContextProcessor());
                wrappedHandler.logResultCount();
                return handler.getList();
            }
        });
    }

    public List search(final Name base, final String filter, final SearchControls controls, final ContextMapper mapper,
                       final DirContextProcessor processor) {
        Preconditions.checkArgument(!controls.getReturningObjFlag());
        final AttributeToContextCallbackHandler handler = new AttributeToContextCallbackHandler<>(mapper);
        search(base, filter, controls, handler, processor);
        return handler.getList();
    }

    public Object lookup(final Name dn) {
        final String operationDescription = "lookup on " + dn;
        return invokeWithContextClassLoader(new TimedSupplier<Object>(operationDescription, logThreshold) {
            public Object timedGet() {
                SearchExecutor se = (ctx) -> {
                    final SearchControls searchControls = new SearchControls();
                    searchControls.setSearchScope(SearchControls.OBJECT_SCOPE);
                    searchControls.setReturningAttributes(null);
                    searchControls.setReturningObjFlag(false);
                    return ctx.search(dn, "(objectClass=*)", searchControls);
                };
                final LookupCallbackHandler handler = new LookupCallbackHandler<>();
                final ExecutionInfoNameClassPairCallbackHandler wrappedHandler = wrapHandler(handler);
                template.search(se, wrappedHandler);
                wrappedHandler.logResultCount();
                return Iterables.getFirst(handler.getList(), null);
            }
        });
    }

    public void search(final Name base, final String filter, final SearchControls controls,
                       final AttributeToContextCallbackHandler handler, final DirContextProcessor processor) {
        Preconditions.checkArgument(!controls.getReturningObjFlag());
        final String operationDescription = "search with handler on baseDN: " + base + ", filter: " + filter;
        invokeWithContextClassLoader(new TimedSupplier<Object>(operationDescription, logThreshold) {
            public Void timedGet() {
                SearchExecutor se = (ctx) -> ctx.search(base, filter, controls);
                final ExecutionInfoNameClassPairCallbackHandler wrappedHandler = wrapHandler(handler);
                template.search(se, wrappedHandler, processor);
                wrappedHandler.logResultCount();
                return null;
            }
        });
    }

    public void unbind(final Name dn) {
        invokeWithContextClassLoader(new TimedSupplier<Object>("unbind on " + dn, logThreshold) {
            public Void timedGet() {
                template.unbind(dn);
                return null;
            }
        });
    }

    public void bind(final Name dn, final Object obj, final Attributes attributes) {
        invokeWithContextClassLoader(new TimedSupplier<Object>("bind on " + dn, logThreshold) {
            public Void timedGet() {
                template.bind(dn, obj, attributes);
                return null;
            }
        });
    }

    public void rename(final String oldDn, final String newDn) {
        invokeWithContextClassLoader(new TimedSupplier<Object>("rename " + oldDn + " -> " + newDn, logThreshold) {
            public Void timedGet() {
                template.rename(oldDn, newDn);
                return null;
            }
        });
    }

    public void modifyAttributes(final Name dn, final ModificationItem[] mods) {
        invokeWithContextClassLoader(new TimedSupplier<Object>("modify attributes on " + dn, logThreshold) {
            public Void timedGet() {
                template.modifyAttributes(dn, mods);
                return null;
            }
        });
    }

    public void lookup(final LdapName dn, final String[] attributes, final AttributesMapper mapper) {
        invokeWithContextClassLoader(new TimedSupplier<Object>("lookup on " + dn, logThreshold) {
            public Void timedGet() {
                final Object result = template.lookup(dn, attributes, mapper);
                logger.trace("Lookup result: [{}]", result);
                return null;
            }
        });
    }

    public <T> T lookup(final LdapName dn, final ContextMapperWithRequiredAttributes<T> mapper) {
        final Set<String> attrSet = mapper.getRequiredLdapAttributes();
        final String[] attributes = attrSet.toArray(new String[attrSet.size()]);

        return invokeWithContextClassLoader(new TimedSupplier<T>("lookup with mapper on " + dn, logThreshold) {
            @SuppressWarnings("unchecked")
            public T timedGet() {
                final T result = (T) template.lookup(dn, attributes, mapper);
                logger.trace("Lookup result: [{}]", result);
                return result;

            }
        });
    }

    public void setIgnorePartialResultException(boolean ignore) {
        /* This doesn't load classes */
        template.setIgnorePartialResultException(ignore);
    }

    private void search(final SearchExecutor se, final NameClassPairCallbackHandler handler, final DirContextProcessor processor) {
        final String operationDescription = "search using searchexecutor " + se;
        invokeWithContextClassLoader(new TimedSupplier<Object>(operationDescription, logThreshold) {
            public Void timedGet() {
                final ExecutionInfoNameClassPairCallbackHandler wrappedHandler = wrapHandler(handler);
                template.search(se, wrappedHandler, processor);
                wrappedHandler.logResultCount();
                return null;
            }
        });
    }

    private ExecutionInfoNameClassPairCallbackHandler wrapHandler(final NameClassPairCallbackHandler handler) {
        return new ExecutionInfoNameClassPairCallbackHandler<>(handler);
    }

    public List searchWithLimitedResults(final Name baseDN, final String filter, final SearchControls searchControls,
                                         ContextMapper contextMapper, DirContextProcessor processor, final int limit) {
        Preconditions.checkArgument(!searchControls.getReturningObjFlag());
        SearchExecutor se = new SearchExecutor() {
            @SuppressFBWarnings(value = "LDAP_INJECTION", justification = "No user input, filter is encoded in all calls")
            public NamingEnumeration<SearchResult> executeSearch(DirContext ctx) throws javax.naming.NamingException {
                NamingEnumeration<SearchResult> ne = ctx.search(baseDN, filter, searchControls);

                if (limit != EntityQuery.ALL_RESULTS) {
                    return new LimitedNamingEnumeration<SearchResult>(ne, limit);
                } else {
                    return ne;
                }
            }

            @Override
            public String toString() {
                return "baseDN: " + baseDN + ", filter: " + filter;
            }
        };

        ContextMapperCallbackHandler handler = new AttributeToContextCallbackHandler(contextMapper);

        search(se, handler, processor);

        return handler.getList();
    }
}
