package com.atlassian.crowd.dao.membership.cache;

import com.atlassian.cache.Cache;
import com.atlassian.crowd.dao.direntity.DirectoryEntityResolver;
import com.atlassian.crowd.dao.direntity.LocallyCachedDirectoryEntityResolver;
import com.atlassian.crowd.embedded.impl.IdentifierUtils;
import com.atlassian.crowd.model.DirectoryEntities;
import com.atlassian.crowd.model.DirectoryEntity;
import com.atlassian.crowd.model.InternalDirectoryEntity;
import com.atlassian.crowd.util.cache.LocalCacheUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;

import javax.annotation.Nullable;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ScheduledExecutorService;

/**
 * Membership cache implementation.
 * <p>
 * Internally there are multiple caches created - 1 per directory and query type pair. Cached values are lists of
 * entity names, stored under keys normalised with {@link IdentifierUtils#toLowerCase(String)}.
 * <p>
 * Invalidations requested through {@code invalidateCache(...)} are collected per thread in a
 * {@link CacheInvalidations} bound to the current Spring transaction. They are applied to the caches only after that
 * transaction completes with status committed or unknown. Until then, reads and writes of invalidated keys made on
 * the same thread bypass the cache.
 */
public class MembershipCache {
    private final CacheFactory cacheFactory;
    private final ConcurrentMap<QueryTypeCacheKey, Cache<String, List<String>>> caches;
    @Nullable
    private final DirectoryEntityResolver entityResolver;
    private final Set<QueryType> cacheableTypes;
    private final Duration cacheTtl;
    private final int groupMembershipCacheMax;
    private final int queryTypeInvalidationThreshold;

    /**
     * Invalidations pending in the transaction currently active on this thread, or no value when none have been
     * requested. The value is bound by {@link #getCacheInvalidation()} and unbound on transaction completion/suspension.
     */
    final ThreadLocal<CacheInvalidations> cacheInvalidationThreadLocal = new ThreadLocal<>();

    /**
     * @param cacheFactory                   factory which will be used to create caches - 1 per directory and query type pair
     * @param cacheableTypes                 set of query types that should be cached
     * @param cacheTtl                       time for which entries are valid after write
     * @param groupMembershipCacheMax        maximum cache size (per directory and query type pair)
     * @param queryTypeInvalidationThreshold when number of invalidated keys goes above this threshold then all entries
     *                                       for the given query type will be invalidated; this limits the number
     *                                       of invalidation messages sent to the cluster
     * @param entityResolver                 implementation of {@link DirectoryEntityResolver} to translate entity names to entities;
     *                                       there are following options:
     *                                       <ul>
     *                                       <li>Providing {@code null} will result in only names caching. Entity query
     *                                       results will never be returned from cache, however the names of the
     *                                       results of such query will be cached.</li>
     *                                       <li>Providing {@link LocallyCachedDirectoryEntityResolver}. The cache does not expire
     *                                       entities on modifications, so it may return stale data.</li>
     *                                       <li>Providing custom implementation of {@link DirectoryEntityResolver}. This is
     *                                       especially useful if product already caches entities by name.</li>
     *                                       </ul>
     * @param cleanupPool                    thread pool for cache cleanup task
     */
    public MembershipCache(CacheFactory cacheFactory,
                           Set<QueryType> cacheableTypes,
                           Duration cacheTtl,
                           int groupMembershipCacheMax,
                           int queryTypeInvalidationThreshold,
                           @Nullable DirectoryEntityResolver entityResolver,
                           ScheduledExecutorService cleanupPool) {
        this.cacheFactory = cacheFactory;
        this.cacheableTypes = ImmutableSet.copyOf(cacheableTypes);
        this.cacheTtl = cacheTtl;
        this.groupMembershipCacheMax = groupMembershipCacheMax;
        this.queryTypeInvalidationThreshold = queryTypeInvalidationThreshold;
        // local map of per-(directory, query type) caches; entries expire after cacheTtl without access
        this.caches = LocalCacheUtils.createExpiringAfterAccessMap(cacheTtl, cleanupPool);
        this.entityResolver = entityResolver;
    }

    /**
     * @return the immutable set of query types this cache stores results for
     */
    public Set<QueryType> getCacheableTypes() {
        return cacheableTypes;
    }

    /**
     * Schedules invalidation of everything cached for the directory, applied when the current transaction completes.
     *
     * @param directoryId directory whose entries should be invalidated
     * @throws IllegalStateException if Spring transaction synchronization is not active on this thread
     */
    public void invalidateCache(long directoryId) {
        getCacheInvalidation().addInvalidation(directoryId);
    }

    /**
     * Schedules invalidation of all entries of one query type in a directory, applied when the current transaction
     * completes.
     *
     * @param directoryId directory whose entries should be invalidated
     * @param queryType   query type to invalidate
     * @throws IllegalStateException if Spring transaction synchronization is not active on this thread
     */
    public void invalidateCache(long directoryId, QueryType queryType) {
        getCacheInvalidation().addInvalidation(directoryId, queryType);
    }

    /**
     * Schedules invalidation of a single key, applied when the current transaction completes.
     *
     * @param directoryId directory of the entry
     * @param queryType   query type of the entry
     * @param key         cache key; normalised the same way as keys stored by {@link #put}
     * @throws IllegalStateException if Spring transaction synchronization is not active on this thread
     */
    public void invalidateCache(long directoryId, QueryType queryType, String key) {
        // entries are stored under the lowercased key, so record the invalidation under the same form
        getCacheInvalidation().addInvalidation(directoryId, queryType, IdentifierUtils.toLowerCase(key));
    }

    /**
     * Caches the names of the given query results, unless an invalidation of this key is pending in the current
     * transaction on this thread.
     * <p>
     * If an entity resolver is configured and the results are supported {@link DirectoryEntity} instances (judged
     * by the first element), the entities are also handed to the resolver.
     *
     * @param directoryId directory the query was run against
     * @param queryType   query type; must be one of {@link #getCacheableTypes()}
     * @param key         cache key (case-insensitive)
     * @param data        query results: either names ({@link String}) or {@link DirectoryEntity} instances
     * @throws IllegalArgumentException if {@code queryType} is not cacheable
     */
    @SuppressWarnings("unchecked")
    public <T> void put(long directoryId, QueryType queryType, String key, List<T> data) {
        QueryTypeCacheKey cacheKey = new QueryTypeCacheKey(directoryId, queryType);
        if (isInvalidated(cacheKey, key)) {
            // the data may predate a modification made in this transaction; don't cache it
            return;
        }
        if (entityResolver != null && !data.isEmpty()) {
            Object first = data.get(0);
            if (first instanceof DirectoryEntity && supports(first.getClass())) {
                entityResolver.putAll((List<DirectoryEntity>) data);
            }
        }
        getOrCreateCache(cacheKey).put(IdentifierUtils.toLowerCase(key), namesOf(data));
    }

    /**
     * Converts query results to an immutable list of names. The element type is judged by the first element only:
     * {@link String} lists are copied as-is, anything else is treated as a list of {@link DirectoryEntity}.
     */
    @SuppressWarnings("unchecked")
    private List<String> namesOf(List<?> list) {
        if (list.isEmpty() || list.get(0) instanceof String) {
            return ImmutableList.copyOf((List<String>) list);
        } else {
            return ImmutableList.copyOf(DirectoryEntities.namesOf((List<DirectoryEntity>) list));
        }
    }

    /**
     * Returns cached query results in the requested representation.
     *
     * @param directoryId directory the query was run against
     * @param queryType   query type; must be one of {@link #getCacheableTypes()}
     * @param key         cache key (case-insensitive)
     * @param returnType  {@link String} for names, or an entity class accepted by {@link #supports(Class)}
     * @return the cached names for {@code String.class}; otherwise the resolved entities, or {@code null} when the key
     * is not cached, is invalidated in the current transaction, or {@code returnType} is not supported. For entity
     * types the result is whatever {@link DirectoryEntityResolver#resolveAllOrNothing} returns (presumably
     * {@code null} if not every name can be resolved — confirm against the resolver)
     * @throws IllegalArgumentException if {@code queryType} is not cacheable
     */
    @Nullable
    @SuppressWarnings("unchecked")
    public <T> List<T> get(long directoryId, QueryType queryType, String key, Class<T> returnType) {
        List<String> names = getNames(directoryId, queryType, key);
        if (returnType == String.class) {
            return (List<T>) names;
        }
        if (names == null || !supports(returnType)) {
            return null;
        }
        // supports(...) only accepts non-String types when entityResolver != null, so it is non-null here
        return (List<T>) entityResolver.resolveAllOrNothing(directoryId, names,
                (Class<DirectoryEntity>) returnType);
    }

    /**
     * Returns the cached names for the key.
     *
     * @return the cached names, or {@code null} if not cached or invalidated in the current transaction on this thread
     * @throws IllegalArgumentException if {@code queryType} is not cacheable
     */
    @Nullable
    public List<String> getNames(long directoryId, QueryType queryType, String key) {
        QueryTypeCacheKey cacheKey = new QueryTypeCacheKey(directoryId, queryType);
        if (isInvalidated(cacheKey, key)) {
            return null;
        }
        return getOrCreateCache(cacheKey).get(IdentifierUtils.toLowerCase(key));
    }

    /**
     * Applies collected invalidations to the caches: whole query-type caches are emptied first, then individual keys
     * are removed. Caches are obtained via {@link #getOrCreateCache}, so ones not yet known locally are created.
     *
     * @param invalidations invalidations collected during a completed transaction
     */
    protected void processInvalidations(CacheInvalidations invalidations) {
        invalidations.getQueryTypesInvalidations()
                .stream()
                .map(this::getOrCreateCache)
                .forEach(Cache::removeAll);

        for (Map.Entry<QueryTypeCacheKey, Set<String>> entry : invalidations.getKeyInvalidations().entrySet()) {
            Cache<String, List<String>> cache = getOrCreateCache(entry.getKey());
            for (String key : entry.getValue()) {
                // entries are stored under the lowercased key; lowercasing is idempotent, so this is safe even if
                // the key was already normalised
                cache.remove(IdentifierUtils.toLowerCase(key));
            }
        }
    }

    /**
     * Returns the cache for the (directory, query type) pair, creating it with the configured TTL and size limit when
     * it is not present in the local map.
     *
     * @throws IllegalArgumentException if the query type is not cacheable
     */
    protected Cache<String, List<String>> getOrCreateCache(QueryTypeCacheKey cacheKey) {
        Preconditions.checkArgument(cacheableTypes.contains(cacheKey.getQueryType()));
        return caches.computeIfAbsent(cacheKey, k -> cacheFactory.createCache(k, cacheTtl, groupMembershipCacheMax));
    }

    /**
     * Returns the invalidations collected for the current transaction on this thread. On first use, a new collection
     * is created and a transaction synchronization is registered. That synchronization applies the invalidations after
     * a committed (or unknown-outcome) completion and handles suspension/resumption.
     *
     * @throws IllegalStateException if Spring transaction synchronization is not active on this thread
     */
    protected CacheInvalidations getCacheInvalidation() {
        CacheInvalidations existing = cacheInvalidationThreadLocal.get();
        if (existing != null) {
            return existing;
        }
        final CacheInvalidations invalidations = new CacheInvalidations(cacheableTypes, queryTypeInvalidationThreshold);
        // Register before binding the thread-local: registerSynchronization throws IllegalStateException when no
        // synchronization is active. Binding first would leave the collection attached to the thread forever, and
        // later invalidations on this thread would then be collected into it but never applied.
        TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
            @Override
            public void afterCompletion(int status) {
                cacheInvalidationThreadLocal.remove();
                if (status == TransactionSynchronization.STATUS_COMMITTED
                        || status == TransactionSynchronization.STATUS_UNKNOWN) {
                    processInvalidations(invalidations);
                }
            }

            // suspend and resume implemented to handle nested transactions
            @Override
            public void suspend() {
                cacheInvalidationThreadLocal.remove();
            }

            @Override
            public void resume() {
                cacheInvalidationThreadLocal.set(invalidations);
            }
        });
        cacheInvalidationThreadLocal.set(invalidations);
        return invalidations;
    }

    /**
     * @return {@code true} if the key has a pending invalidation in the current transaction on this thread
     */
    protected boolean isInvalidated(QueryTypeCacheKey cacheKey, String key) {
        CacheInvalidations invalidations = cacheInvalidationThreadLocal.get();
        // compare in the same normalised form used when recording invalidations
        return invalidations != null && invalidations.isInvalidated(cacheKey, IdentifierUtils.toLowerCase(key));
    }

    /**
     * Immediately empties every cache currently present in the local map (not transaction-aware).
     */
    public void clear() {
        caches.values().forEach(Cache::removeAll);
    }

    /**
     * Immediately empties the caches of every cacheable query type for the directory (not transaction-aware).
     * Goes through {@link #getOrCreateCache}, so caches not present locally are created first (presumably so a
     * shared/clustered cache is cleared as well — confirm against the {@link CacheFactory} implementation).
     */
    public void clear(long directoryId) {
        for (QueryType queryType : cacheableTypes) {
            getOrCreateCache(new QueryTypeCacheKey(directoryId, queryType)).removeAll();
        }
    }

    /**
     * @return the number of (directory, query type) caches currently held in the local map
     */
    public int cacheCount() {
        return caches.size();
    }

    /**
     * Tells whether results of the given type can be returned by {@link #get}.
     *
     * @param resultClass requested result type
     * @return without an entity resolver only {@link String} is supported; with one, any type that is not an
     * {@link InternalDirectoryEntity} is supported
     */
    public boolean supports(Class<?> resultClass) {
        return entityResolver == null ? resultClass == String.class : !InternalDirectoryEntity.class.isAssignableFrom(resultClass);
    }
}
