package com.atlassian.crowd.directory.rfc4519;

import com.atlassian.crowd.directory.NamedLdapEntity;
import com.atlassian.crowd.directory.RFC4519Directory;
import com.atlassian.crowd.directory.ldap.mapper.ContextMapperWithRequiredAttributes;
import com.atlassian.crowd.embedded.api.SearchRestriction;
import com.atlassian.crowd.exception.OperationFailedException;
import com.atlassian.crowd.model.group.ImmutableMembership;
import com.atlassian.crowd.model.group.Membership;
import com.atlassian.crowd.search.EntityDescriptor;
import com.atlassian.crowd.search.builder.QueryBuilder;
import com.atlassian.crowd.search.query.entity.EntityQuery;
import com.atlassian.crowd.search.query.entity.restriction.BooleanRestriction;
import com.atlassian.crowd.search.query.entity.restriction.BooleanRestrictionImpl;
import com.atlassian.crowd.search.query.entity.restriction.MatchMode;
import com.atlassian.crowd.search.query.entity.restriction.PropertyImpl;
import com.atlassian.crowd.search.query.entity.restriction.TermRestriction;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.naming.InvalidNameException;
import javax.naming.ldap.LdapName;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
 * An {@link Iterable} over group {@link Membership}s optimised for the case
 * where we already have all or some of the DNs and names of possible users and sub-groups.
 *
 * Memberships are fetched while iterating (batched by the partition size)
 * If some dn to name mappings are missing, these are fetched in batches while iterating.
 */
public class RFC4519DirectoryMembershipsIterable implements Iterable<Membership> {
    private static final Logger logger = LoggerFactory.getLogger(RFC4519DirectoryMembershipsIterable.class);

    private static final int MISSING_NAMES_PARTITION_SIZE = Integer.getInteger("com.atlassian.crowd.directory.RFC4519DirectoryMembershipsIterable.MISSING_NAMES_PARTITION_SIZE", 1000);

    private final RFC4519Directory connector;
    private final Map<LdapName, Optional<String>> users; // value can be in 3 different states: null which means that the LdapName wasn't searched before,
    // empty optional which means the LdapName was searched but nothing found
    // not empty optional which means the LdapName was searched and found
    private final Map<LdapName, Optional<String>> groups; // same as above
    private final Map<LdapName, String> groupsToInclude;
    private final int membershipBatchSize;
    private final ContextMapperWithRequiredAttributes<LdapName> dnMapper;

    RFC4519DirectoryMembershipsIterable(RFC4519Directory connector,
                                        Map<LdapName, String> users, Map<LdapName, String> groups,
                                        Map<LdapName, String> groupsToInclude,
                                        int membershipBatchSize,
                                        ContextMapperWithRequiredAttributes<LdapName> dnMapper) {
        this.connector = connector;
        this.users = users.entrySet().stream().collect(Collectors.toMap(Entry::getKey, entry -> Optional.ofNullable(entry.getValue())));
        this.groups = groups.entrySet().stream().collect(Collectors.toMap(Entry::getKey, entry -> Optional.ofNullable(entry.getValue())));
        this.groupsToInclude = groupsToInclude;
        this.membershipBatchSize = membershipBatchSize;
        this.dnMapper = dnMapper;
    }

    @Override
    public Iterator<Membership> iterator() {
        final Iterable<List<Entry<LdapName, String>>> partitioned = Iterables.partition(groupsToInclude.entrySet(), membershipBatchSize);
        return StreamSupport.stream(partitioned.spliterator(), false).map(partition -> {
            try {
                return getMemberships(partition);
            } catch (OperationFailedException e) {
                throw new Membership.MembershipIterationException(e);
            }
        })
                .flatMap(iterable -> StreamSupport.stream(iterable.spliterator(), false))
                .iterator();
    }

    private Iterable<Membership> getMemberships(Collection<Entry<LdapName, String>> groups) throws OperationFailedException {
        if (!groups.iterator().hasNext()) {
            return Collections.emptyList();
        }
        final List<MembershipHolder> memberships = searchChildrenDns(groups);

        lookupMissingNames(memberships);

        return resolveMemberships(memberships);
    }

    protected void lookupMissingNames(List<MembershipHolder> memberships) throws OperationFailedException {
        final Set<LdapName> missingLdapNames = findMissingEntriesInCacheAndCreateRestrictionsForThem(memberships);

        if (!missingLdapNames.isEmpty()) {
            updateCache(missingLdapNames);
        }
    }

    private List<MembershipHolder> searchChildrenDns(Collection<Entry<LdapName, String>> groups) {
        long start = System.currentTimeMillis();
        logger.info("Searching for children of {} groups", groups.size());

        final List<MembershipHolder> memberships = new ArrayList<>();
        for (final Entry<LdapName, String> groupDnToName : groups) {
            try {
                memberships.add(new MembershipHolder(groupDnToName.getValue(), ImmutableSet.copyOf(connector.findDirectMembersOfGroup(groupDnToName.getKey(), dnMapper))));
            } catch (OperationFailedException e) {
                throw new Membership.MembershipIterationException(e);
            }
        }
        logger.info("Found {} children for {} groups in {} ms", countMembershipsChildren(memberships), groups.size(), System.currentTimeMillis() - start);
        return memberships;
    }

    @VisibleForTesting
    long countMembershipsChildren(final List<MembershipHolder> memberships){
        return memberships.stream()
                .map(MembershipHolder::getChildren)
                .mapToLong(Collection::size).sum();
    }

    private Set<LdapName> findMissingEntriesInCacheAndCreateRestrictionsForThem(List<MembershipHolder> memberships) {
        return memberships.stream().map(holder -> holder.getChildren().stream()
                .filter(child -> !users.containsKey(child) && !groups.containsKey(child))
                .collect(Collectors.toList())).flatMap(Collection::stream)
                .collect(Collectors.toSet());
    }

    @SuppressFBWarnings(value = "LDAP_INJECTION", justification = "No user input - the String was just retrieved from the LdapName")
    private void updateCache(Collection<LdapName> missingNames) throws OperationFailedException {
        long start = System.currentTimeMillis();
        logger.info("Fetching details for {} entities for membership resolution", missingNames.size());

        final Set<TermRestriction<String>> restrictionsForMissingNames = missingNames.stream().map(n -> new TermRestriction<>(new PropertyImpl<>("distinguishedName", String.class), MatchMode.EXACTLY_MATCHES, n.toString()))
                .collect(Collectors.toSet());

        final Collection<NamedLdapEntity> userMembers = searchUsers(restrictionsForMissingNames, MISSING_NAMES_PARTITION_SIZE);

        final Collection<String> usersDns = getDns(userMembers);
        final Collection<TermRestriction<String>> restrictionsForChildrenGroups = filterOutRestrictionsForDns(restrictionsForMissingNames, usersDns);
        final Collection<NamedLdapEntity> groupsMembers = searchGroups(restrictionsForChildrenGroups, MISSING_NAMES_PARTITION_SIZE);

        for (TermRestriction<String> restrictionForNotFoundChild : filterOutRestrictionsForDns(restrictionsForChildrenGroups, getDns(groupsMembers))) {
            try {
                users.put(new LdapName(restrictionForNotFoundChild.getValue()), Optional.empty());
                groups.put(new LdapName(restrictionForNotFoundChild.getValue()), Optional.empty());
            } catch (InvalidNameException e) {
                throw new OperationFailedException(e);
            }
        }
        userMembers.forEach(n -> users.put(n.getDn(), Optional.of(n.getName())));
        groupsMembers.forEach(n -> groups.put(n.getDn(), Optional.of(n.getName())));
        logger.debug("Updating cache took {} ms", System.currentTimeMillis() - start);
    }

    private List<TermRestriction<String>> filterOutRestrictionsForDns(Collection<TermRestriction<String>> restrictions, Collection<String> dns) {
        return restrictions.stream()
                .filter(r -> !dns.contains(r.getValue()))
                .collect(Collectors.toList());
    }

    private Collection<NamedLdapEntity> searchUsers(Collection<TermRestriction<String>> restrictions, final int partitionSize) throws OperationFailedException {
        logger.debug("Searching for {} users in directory", restrictions.size());
        final long start = System.currentTimeMillis();

        final List<NamedLdapEntity> result = new ArrayList<>();
        for (final List<TermRestriction<String>> batchedRestrictions : Iterables.partition(restrictions, partitionSize)) {
            final EntityQuery<String> query = QueryBuilder.queryFor(String.class, EntityDescriptor.user())
                    .with(prepareBooleanRestrictionForTermRestrictions(batchedRestrictions))
                    .returningAtMost(EntityQuery.ALL_RESULTS);
            if (logger.isTraceEnabled()) {
                logger.trace("Searching user objects using query {}", query.toString());
            }
            result.addAll(connector.searchUserObjects(query, new NamedLdapEntity.NamedEntityMapper(connector.getLdapPropertiesMapper().getUserNameAttribute())));
        }

        logger.debug("Found {} users in directory in {} ms", result.size(), System.currentTimeMillis() - start);
        return result;
    }

    private Collection<NamedLdapEntity> searchGroups(Collection<TermRestriction<String>> restrictions, final int partitionSize) throws OperationFailedException {
        logger.debug("Searching for {} groups in directory", restrictions.size());
        final long start = System.currentTimeMillis();

        final List<NamedLdapEntity> result = new ArrayList<>();
        for (final List<TermRestriction<String>> batchedRestrictions : Iterables.partition(restrictions, partitionSize)) {
            result.addAll(connector.searchGroupObjects(QueryBuilder.queryFor(String.class, EntityDescriptor.group())
                    .with(prepareBooleanRestrictionForTermRestrictions(batchedRestrictions))
                    .returningAtMost(EntityQuery.ALL_RESULTS), new NamedLdapEntity.NamedEntityMapper(connector.getLdapPropertiesMapper().getGroupNameAttribute())));
        }

        logger.debug("Found {} groups in directory in {} ms", result.size(), System.currentTimeMillis() - start);
        return result;
    }

    private BooleanRestrictionImpl prepareBooleanRestrictionForTermRestrictions(List<TermRestriction<String>> batchedRestrictions) {
        return new BooleanRestrictionImpl(BooleanRestriction.BooleanLogic.OR, batchedRestrictions.toArray(new SearchRestriction[batchedRestrictions.size()]));
    }

    private Collection<Membership> resolveMemberships(Collection<MembershipHolder> memberships) {
        return memberships.stream().map(this::mapHolderToMembership).collect(Collectors.toList());
    }

    private ImmutableMembership mapHolderToMembership(MembershipHolder m) {
        return new ImmutableMembership(m.getName(),
                m.getChildren().stream().map(users::get).filter(Objects::nonNull).filter(Optional::isPresent).map(Optional::get).collect(Collectors.toList()),
                m.getChildren().stream().map(groups::get).filter(Objects::nonNull).filter(Optional::isPresent).map(Optional::get).collect(Collectors.toList()));
    }

    private Collection<String> getDns(Collection<NamedLdapEntity> ldapEntities) {
        return ldapEntities.stream().map(NamedLdapEntity::getDn).map(LdapName::toString).collect(Collectors.toList());
    }

    static class MembershipHolder {
        private final String name;
        private final Set<LdapName> children;

        MembershipHolder(String name, Set<LdapName> children) {
            this.name = name;
            this.children = Preconditions.checkNotNull(children, "children cannot be null");
        }

        public String getName() {
            return name;
        }

        public Set<LdapName> getChildren() {
            return children;
        }
    }

    @VisibleForTesting
    ContextMapperWithRequiredAttributes<LdapName> getDnMapper() {
        return dnMapper;
    }
}
