package com.atlassian.vcache.internal.core.service;

import com.atlassian.vcache.RequestCache;
import com.atlassian.vcache.VCacheException;
import com.atlassian.vcache.internal.RequestContext;

import javax.annotation.Nullable;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.StampedLock;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static com.atlassian.vcache.internal.NameValidator.requireValidCacheName;
import static java.util.Objects.requireNonNull;

/**
 * Read optimised version of a {@link RequestCache}. This cache is locked with a {@link StampedLock} which
 * uses optimistic locking on the read path to improve performance when reading. Note that using this cache
 * in a write heavy use case will cause poorer performance than {@link DefaultRequestCache}.
 * <p>
 * The data lives in a plain {@link HashMap} paired with its own {@link StampedLock}, obtained per
 * {@link RequestContext} via {@code RequestContext.computeIfAbsent}. Write operations are re-entrant on the
 * thread that holds the write lock. This matters for {@link #getBulk(Function, Iterable)}, which invokes its
 * factory while holding the write lock: the factory may call back into this cache without waiting on its own
 * (non re-entrant) {@link StampedLock}.
 * <p>
 * For a good description of the semantics of a {@link StampedLock}
 * see: https://www.javaspecialists.eu/archive/Issue215.html
 *
 * @param <K> The key type
 * @param <V> The value type
 * @since 1.13.0
 */
class ReadOptimisedRequestCache<K, V> implements RequestCache<K, V> {

    /**
     * Holds {@code TRUE} on a thread only while that thread owns this cache's write lock. Calls made back into
     * the cache from inside the lock (e.g. by a {@code getBulk} factory) use it to skip re-locking. It is left
     * unset otherwise, so threads that are not currently writing keep no entry.
     */
    private final ThreadLocal<Boolean> inWriteLock = new ThreadLocal<>();

    private final String name;
    private final Supplier<RequestContext> contextSupplier;
    private final Duration lockTimeout;

    /**
     * @param name            the cache name, checked by {@code requireValidCacheName}
     * @param contextSupplier supplies the request context that owns this cache's map and lock
     * @param lockTimeout     how long to wait for the read or write lock before failing with a
     *                        {@link VCacheException}
     */
    ReadOptimisedRequestCache(String name, Supplier<RequestContext> contextSupplier, Duration lockTimeout) {
        this.name = requireValidCacheName(name);
        this.contextSupplier = requireNonNull(contextSupplier);
        this.lockTimeout = requireNonNull(lockTimeout);
    }

    @Override
    public Optional<V> get(K key) {
        return Optional.ofNullable(withOptimisticReadLock(map -> map.get(key)));
    }

    /**
     * Returns the cached value for {@code key}, or computes it with {@code supplier} and stores it.
     * The supplier runs without any lock held. If a value for {@code key} was stored in the meantime, that
     * value is kept and returned instead of the supplied one.
     *
     * @throws NullPointerException if the supplier returns {@code null}
     */
    @Override
    public V get(K key, Supplier<? extends V> supplier) {
        return get(key).orElseGet(() -> {
            final V candidateValue = requireNonNull(supplier.get());
            final V existing = withWriteLock(map -> map.putIfAbsent(key, candidateValue));
            return existing == null ? candidateValue : existing;
        });
    }

    /**
     * Returns values for all {@code keys}. Missing values are obtained with a single {@code factory} call and
     * stored in the cache.
     * <p>
     * The whole operation, including the factory call, runs under the write lock. The factory may call back
     * into this cache on the same thread.
     *
     * @throws NullPointerException if {@code keys} contains {@code null}
     */
    @Override
    public Map<K, V> getBulk(final Function<Set<K>, Map<K, V>> factory, final Iterable<K> keys) {
        //noinspection ConstantConditions
        return withWriteLock(map -> loadBulk(map, factory, keys));
    }

    @Override
    public void put(K key, V value) {
        withWriteLock(map -> map.put(key, value));
    }

    @Override
    public Optional<V> putIfAbsent(K key, V value) {
        return Optional.ofNullable(withWriteLock(map -> map.putIfAbsent(key, value)));
    }

    /**
     * @throws NullPointerException if any argument is {@code null}; checked before taking the lock
     */
    @Override
    public boolean replaceIf(K key, V currentValue, V newValue) {
        requireNonNull(key);
        requireNonNull(currentValue);
        requireNonNull(newValue);
        //noinspection ConstantConditions
        return withWriteLock(map -> map.replace(key, currentValue, newValue));
    }

    /**
     * @throws NullPointerException if any argument is {@code null}; checked before taking the lock
     */
    @Override
    public boolean removeIf(K key, V value) {
        requireNonNull(key);
        requireNonNull(value);
        //noinspection ConstantConditions
        return withWriteLock(map -> map.remove(key, value));
    }

    @Override
    public void remove(K key) {
        withWriteLock(map -> map.remove(key));
    }

    @Override
    public void removeAll() {
        withWriteLock(map -> {
            map.clear();
            return null;
        });
    }

    @Override
    public String getName() {
        return name;
    }

    /**
     * Bulk-load logic, executed while the caller holds the write lock on {@code map}.
     * Collects cached values, asks {@code factory} for the rest, and stores the factory's results.
     */
    private Map<K, V> loadBulk(Map<K, V> map, Function<Set<K>, Map<K, V>> factory, Iterable<K> keys) {
        final Set<K> requestedKeys = StreamSupport.stream(keys.spliterator(), false)
                .map(Objects::requireNonNull)
                .collect(Collectors.toSet());

        // Explicit HashMap: Collectors.toMap gives no mutability guarantee, and this map is added to below.
        final Map<K, V> grandResult = new HashMap<>();
        for (K key : requestedKeys) {
            final V cached = map.get(key);
            if (cached != null) {
                grandResult.put(key, cached);
            }
        }

        // Bail out if we have all the values
        if (grandResult.size() == requestedKeys.size()) {
            return grandResult;
        }

        final Set<K> missingKeys = requestedKeys.stream()
                .filter(key -> !grandResult.containsKey(key))
                .collect(Collectors.toSet());

        final Map<K, V> missingValues = factory.apply(missingKeys);
        FactoryUtils.verifyFactoryResult(missingValues, missingKeys);

        missingValues.forEach((key, value) -> {
            // The factory runs under our write lock but may have re-entered this cache and stored the key
            // itself; keep whatever is already there so every caller sees the same value.
            final V existing = map.putIfAbsent(key, value);
            grandResult.put(key, existing != null ? existing : value);
        });

        return grandResult;
    }

    /** Returns the map and lock owned by the current request context for this cache. */
    private MapAndLock<K, V> ensureDelegate() {
        final RequestContext requestContext = contextSupplier.get();
        return requestContext.computeIfAbsent(this, MapAndLock::new);
    }

    /** Whether the current thread currently owns this cache's write lock. */
    private boolean holdsWriteLock() {
        return Boolean.TRUE.equals(inWriteLock.get());
    }

    /**
     * Applies {@code getV} to the current request's map.
     * <p>
     * It first tries an optimistic read. If that read overlapped a write, it retries under a real read lock,
     * waiting at most {@code lockTimeout}.
     *
     * @throws VCacheException if the read lock cannot be acquired in time, or the wait is interrupted (the
     *                         thread's interrupt status is restored)
     */
    @Nullable
    private <R> R withOptimisticReadLock(Function<Map<K, V>, R> getV) {
        final MapAndLock<K, V> mapAndLock = ensureDelegate();
        final StampedLock lock = mapAndLock.lock;

        final long optimisticStamp = lock.tryOptimisticRead();
        if (optimisticStamp != 0) {
            try {
                final R value = getV.apply(mapAndLock.map);
                if (lock.validate(optimisticStamp)) {
                    return value;
                }
            } catch (RuntimeException e) {
                // The HashMap is not thread-safe: a read racing a write may see a torn structure and fail.
                // Only surface the failure if no write overlapped the read; otherwise retry under the read lock.
                if (lock.validate(optimisticStamp)) {
                    throw e;
                }
            }
        } else if (holdsWriteLock()) {
            // A zero stamp means the write lock is held. If this thread is the holder (e.g. a getBulk factory
            // reading back from this cache), reading directly is safe; tryReadLock would wait on ourselves.
            return getV.apply(mapAndLock.map);
        }

        final long readLockStamp;
        try {
            readLockStamp = lock.tryReadLock(lockTimeout.toMillis(), TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new VCacheException("Lock acquisition on cache interrupted.", e);
        }
        if (readLockStamp == 0) {
            // The lock acquisition failed.
            throw new VCacheException("Failed to lock cache");
        }
        try {
            return getV.apply(mapAndLock.map);
        } finally {
            lock.unlockRead(readLockStamp);
        }
    }

    /**
     * Applies {@code putV} to the current request's map while holding its write lock, waiting at most
     * {@code lockTimeout} to acquire it.
     * <p>
     * Re-entrant: if this thread already owns the write lock, {@code putV} is applied directly.
     *
     * @throws VCacheException if the write lock cannot be acquired in time, or the wait is interrupted (the
     *                         thread's interrupt status is restored)
     */
    @Nullable
    private <R> R withWriteLock(Function<Map<K, V>, R> putV) {
        final MapAndLock<K, V> mapAndLock = ensureDelegate();
        if (holdsWriteLock()) {
            return putV.apply(mapAndLock.map);
        }

        final long stamp;
        try {
            stamp = mapAndLock.lock.tryWriteLock(lockTimeout.toMillis(), TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new VCacheException("Interrupted acquiring write lock", e);
        }
        if (stamp == 0) {
            throw new VCacheException("Could not acquire write lock");
        }

        inWriteLock.set(Boolean.TRUE);
        try {
            return putV.apply(mapAndLock.map);
        } finally {
            inWriteLock.remove();
            mapAndLock.lock.unlockWrite(stamp);
        }
    }

    /** Per-request state: a non thread-safe map guarded by its own lock. */
    private static final class MapAndLock<K, V> {
        final Map<K, V> map = new HashMap<>();
        final StampedLock lock = new StampedLock();
    }
}
