package com.atlassian.vcache.internal.core.service;

import com.atlassian.vcache.JvmCache;

import javax.annotation.Nonnull;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;

/**
 * Base implementation of {@link JvmCache} that performs locking to ensure that multi-thread consistency is
 * maintained. The example scenario is:
 * <ol>
 * <li>Thread 1 - begins call to {@link #get(Object, Supplier)}</li>
 * <li>Thread 1 - begins call to {@link Supplier#get()} as value is missing</li>
 * <li>Thread 2 - begins and ends a call to either {@link #remove(Object)} or {@link #removeAll()}</li>
 * <li>Thread 1 - ends call to {@link Supplier#get()}</li>
 * <li>Thread 1 - ends call to {@link #get(Object, Supplier)}</li>
 * </ol>
 * Without locking, the cache would end up holding the value computed by Thread 1, even though Thread 2's
 * removal happened after that computation started and should have invalidated it.
 * <p>
 * While a supplier passed to {@link #get(Object, Supplier)} is running (and until the decorated get returns),
 * two locks are held:
 * <ul>
 * <li>a per-key barrier, also taken by {@link #remove(Object)}, so that removing the same key waits for the
 * in-flight computation to complete;</li>
 * <li>the read side of a fair read/write lock whose write side is taken by {@link #removeAll()}, so that a
 * full clear waits for every in-flight computation.</li>
 * </ul>
 * If the decorated get returns without invoking the supplier, no locks are taken.
 *
 * @param <K> the key type
 * @param <V> the value type
 * @since 1.0.0
 */
public abstract class AbstractLockingJvmCache<K, V> implements JvmCache<K, V> {
    /** One barrier per key currently being computed or removed; absent when the key is unlocked. */
    private final ConcurrentMap<K, OneShotLatch> barriers = new ConcurrentHashMap<>(16);
    /** Shared (read) side: held while a supplier runs. */
    private final Lock supplierLock;
    /** Exclusive (write) side: held while {@link #removeAll()} runs. */
    private final Lock removeAllLock;

    {
        // This needs to be fair to ensure that removeAll does not starve for a busy cache
        final ReadWriteLock supplierVsRemoveAllLock = new ReentrantReadWriteLock(true);
        supplierLock = supplierVsRemoveAllLock.readLock();
        removeAllLock = supplierVsRemoveAllLock.writeLock();
    }

    /**
     * Performs the {@link #get(Object, Supplier)} operation on the underlying cache.
     */
    @Nonnull
    protected abstract V decoratedGet(K key, Supplier<? extends V> supplier);

    /**
     * Performs the {@link #remove(Object)} operation on the underlying cache.
     */
    protected abstract void decoratedRemove(K key);

    /**
     * Performs the {@link #removeAll()} operation on the underlying cache.
     */
    protected abstract void decoratedRemoveAll();

    /**
     * Returns the value for {@code key}, delegating to {@link #decoratedGet(Object, Supplier)}. If the decorated
     * implementation invokes the supplier, the per-key barrier and the supplier (read) lock are acquired first
     * and held until the decorated get returns, so a concurrent {@link #remove(Object)} of the same key or a
     * {@link #removeAll()} cannot interleave with the computation.
     */
    @Nonnull
    @Override
    public final V get(K key, Supplier<? extends V> supplier) {
        // Each flag is set only AFTER its lock has actually been acquired. Unlocking a lock that is not held
        // throws IllegalMonitorStateException, which would otherwise replace the real failure (e.g. one raised
        // while waiting for the key barrier).
        final boolean[] keyLocked = new boolean[1];
        final boolean[] supplierLocked = new boolean[1];
        try {
            return decoratedGet(key, () -> {
                acquireLockFor(key);
                keyLocked[0] = true;
                supplierLock.lock();
                supplierLocked[0] = true;
                return supplier.get();
            });
        } finally {
            // Release in reverse order of acquisition.
            if (supplierLocked[0]) {
                supplierLock.unlock();
            }
            if (keyLocked[0]) {
                releaseLockFor(key);
            }
        }
    }

    /**
     * Removes {@code key} while holding its per-key barrier, so the removal is ordered with respect to any
     * in-flight supplier computation for the same key.
     */
    @Override
    public final void remove(K key) {
        acquireLockFor(key);
        try {
            decoratedRemove(key);
        } finally {
            releaseLockFor(key);
        }
    }

    /**
     * Removes all entries while holding the write side of the read/write lock, so it waits for every
     * in-flight supplier computation to finish.
     */
    @Override
    public final void removeAll() {
        removeAllLock.lock();
        try {
            decoratedRemoveAll();
        } finally {
            removeAllLock.unlock();
        }
    }

    /**
     * Raises the barrier for {@code key}, waiting for and then retrying past any barrier already raised by
     * another thread.
     *
     * @return the barrier now owned by the calling thread
     */
    private OneShotLatch acquireLockFor(@Nonnull K key) {
        final OneShotLatch barrier = new OneShotLatch();
        while (true) {
            final OneShotLatch existing = barriers.putIfAbsent(key, barrier);

            // successfully raised a new barrier
            if (existing == null) {
                return barrier;
            }

            // Another barrier is already raised for this key; wait for it to be released, then retry.
            // There is no need to attempt a remove or replace to evict it from the barriers map:
            // the thread that owned it removes it before releasing it (see releaseLockFor).
            existing.await();
        }
    }

    /**
     * Lowers the barrier for {@code key} if, and only if, it is held by the calling thread.
     */
    private void releaseLockFor(K key) {
        final OneShotLatch barrier = barriers.get(key);
        if (barrier != null && barrier.isHeldByCurrentThread()) {
            // Remove before releasing so that woken waiters can immediately raise their own barrier.
            // The conditional remove guarantees only this thread's barrier is ever evicted.
            barriers.remove(key, barrier);
            barrier.release();
        }
    }
}
