package com.atlassian.crowd.manager.directory.nestedgroups;

import com.atlassian.crowd.exception.OperationFailedException;
import com.atlassian.crowd.model.group.BaseImmutableGroup;
import com.atlassian.crowd.model.group.Group;
import com.atlassian.crowd.model.group.ImmutableGroup;
import com.google.common.base.Throwables;
import com.google.common.cache.Cache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import org.apache.commons.lang3.ArrayUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Caching wrapper for {@link MultipleGroupsProvider}.
 *
 * <p>Two caches are used, both keyed by group names passed through {@code idNormalizer}:
 * <ul>
 *     <li>{@code subgroupsCache} maps a normalized group name to the normalized names of its directly related
 *     groups (an empty array when the wrapped provider returned none);</li>
 *     <li>{@code groupsCache} maps a normalized group name to an immutable copy of that group.</li>
 * </ul>
 * A name is only answered from the cache when its name list and every group that list references are present;
 * otherwise it is fetched again from the wrapped provider.
 */
public class CachedMultipleGroupsProvider implements MultipleGroupsProvider {
    private final Cache<String, String[]> subgroupsCache;
    private final Cache<String, Group> groupsCache;
    private final Function<String, String> idNormalizer;
    private final MultipleGroupsProvider provider;

    /**
     * @param subgroupsCache cache of normalized group name to the normalized names of its directly related groups
     * @param groupsCache    cache of normalized group name to group
     * @param idNormalizer   function applied to every group name before it is used as a cache key
     * @param provider       delegate queried for names that cannot be fully served from the caches
     */
    protected CachedMultipleGroupsProvider(Cache<String, String[]> subgroupsCache,
                                           Cache<String, Group> groupsCache,
                                           Function<String, String> idNormalizer,
                                           MultipleGroupsProvider provider) {
        this.subgroupsCache = subgroupsCache;
        this.groupsCache = groupsCache;
        this.idNormalizer = idNormalizer;
        this.provider = provider;
    }

    /**
     * Returns the cached directly related groups of {@code groupName}.
     *
     * @param groupName group name as supplied by the caller (normalized here before the lookup)
     * @return the cached groups, or {@code null} when no name list is cached for {@code groupName} or when any
     *         group referenced by that list is missing from {@code groupsCache}
     */
    private List<Group> get(String groupName) {
        String[] names = subgroupsCache.getIfPresent(idNormalizer.apply(groupName));
        if (names == null) {
            return null;
        }
        List<Group> result = Stream.of(names).map(groupsCache::getIfPresent).collect(Collectors.toList());
        // A partially cached entry (e.g. a group evicted independently) is treated as a miss.
        return result.contains(null) ? null : result;
    }

    /**
     * Stores the provider's answer for {@code names} in both caches.
     *
     * <p>Every name in {@code names} for which the provider returned no entry is cached with an empty array, so a
     * later lookup of it is a cache hit. Groups are written before the name lists that reference them. If the
     * caches are shared between threads, a reader that observes a new name list is then not forced into a miss
     * just because its groups have not been written yet.
     *
     * @param names   the names that were requested from the provider
     * @param results the provider's answer, keyed by group name
     */
    private void addToCache(Set<String> names, ListMultimap<String, Group> results) {
        Map<String, Group> groupMap = new HashMap<>();
        Map<String, String[]> subgroupMap = new HashMap<>();
        for (Map.Entry<String, Collection<Group>> entry : results.asMap().entrySet()) {
            Collection<Group> groups = entry.getValue();
            List<String> subgroupNames = new ArrayList<>(groups.size());
            for (Group group : groups) {
                String normalizedName = idNormalizer.apply(group.getName());
                subgroupNames.add(normalizedName);
                groupMap.computeIfAbsent(normalizedName, ignored -> createImmutableGroup(group));
            }
            subgroupMap.put(idNormalizer.apply(entry.getKey()), subgroupNames.toArray(ArrayUtils.EMPTY_STRING_ARRAY));
        }
        // A multimap has no entry for a key without values, so record such names explicitly as "no related groups".
        for (String name : names) {
            subgroupMap.putIfAbsent(idNormalizer.apply(name), ArrayUtils.EMPTY_STRING_ARRAY);
        }
        // Publish the groups first, then the name lists that point at them.
        groupsCache.putAll(groupMap);
        subgroupsCache.putAll(subgroupMap);
    }

    /**
     * Returns {@code group} unchanged if it is already a {@link BaseImmutableGroup}, otherwise an
     * {@link ImmutableGroup} copy of it.
     */
    private Group createImmutableGroup(Group group) {
        if (group instanceof BaseImmutableGroup) {
            return group;
        }
        return ImmutableGroup.from(group);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Names whose related groups are fully cached are answered from the cache. All remaining names are fetched
     * from the wrapped provider in a single call, and that answer is added to the cache. Repeated names in
     * {@code names} are looked up only once.
     */
    @Override
    public ListMultimap<String, Group> getDirectlyRelatedGroups(Collection<String> names) throws OperationFailedException {
        try {
            ListMultimap<String, Group> results = ArrayListMultimap.create();
            Set<String> missingNames = new HashSet<>();
            // De-duplicate up front. Misses are collapsed by the Set handed to the provider, so hits must be too.
            // Otherwise a repeated name would yield duplicated groups only when served from a warm cache.
            for (String name : new HashSet<>(names)) {
                List<Group> cached = get(name);
                if (cached != null) {
                    results.putAll(name, cached);
                } else {
                    missingNames.add(name);
                }
            }
            if (missingNames.isEmpty()) {
                return results;
            }
            ListMultimap<String, Group> fetchedResults = provider.getDirectlyRelatedGroups(missingNames);
            addToCache(missingNames, fetchedResults);
            results.putAll(fetchedResults);
            return results;
        } catch (ExecutionException e) {
            // Unwrap: rethrow the cause if it is an OperationFailedException or unchecked, otherwise wrap the cause.
            Throwables.propagateIfPossible(e.getCause(), OperationFailedException.class);
            throw new OperationFailedException(e.getCause());
        } catch (Exception e) {
            // OperationFailedException and unchecked exceptions pass through unchanged; anything else is wrapped.
            Throwables.propagateIfPossible(e, OperationFailedException.class);
            throw new OperationFailedException(e);
        }
    }
}
