package com.atlassian.vcache.internal.core.service;

import com.atlassian.vcache.PutPolicy;
import com.atlassian.vcache.TransactionalExternalCache;
import com.atlassian.vcache.internal.ExternalCacheExceptionListener;
import com.atlassian.vcache.internal.MetricLabel;
import com.atlassian.vcache.internal.RequestContext;
import com.atlassian.vcache.internal.core.TransactionControl;
import com.atlassian.vcache.internal.core.metrics.CacheType;
import com.atlassian.vcache.internal.core.metrics.MetricsRecorder;

import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static com.atlassian.vcache.internal.core.VCacheCoreUtils.isEmpty;
import static java.util.Objects.requireNonNull;

/**
 * Provides operations common for {@link com.atlassian.vcache.TransactionalExternalCache} instances.
 * <p>
 * Read operations consult the per-request {@link AbstractExternalCacheRequestContext} first and only
 * query the external cache (via {@link #directGet(String)} / {@link #directGetBulk(Set)}) for keys
 * that have no recorded value. Write and remove operations are only recorded against that context;
 * this class does not send them to the external cache itself.
 *
 * @param <V> the value type
 * @since 1.0.0
 */
public abstract class AbstractTransactionalExternalCache<V>
        extends AbstractExternalCache<V>
        implements TransactionalExternalCache<V>, TransactionControl {

    /** Supplies the current request context, used by {@link #transactionDiscard()}. */
    protected final Supplier<RequestContext> contextSupplier;
    /** Records a {@link MetricLabel#NUMBER_OF_REMOTE_GET} metric each time the external cache is queried. */
    protected final MetricsRecorder metricsRecorder;

    /**
     * Creates the cache.
     *
     * @param name                           the cache name, passed to the superclass
     * @param contextSupplier                supplies the current {@link RequestContext}
     * @param metricsRecorder                recorder for remote-get metrics
     * @param lockTimeout                    passed to the superclass
     * @param externalCacheExceptionListener passed to the superclass
     * @throws NullPointerException if {@code contextSupplier} or {@code metricsRecorder} is {@code null}
     */
    protected AbstractTransactionalExternalCache(
            String name,
            Supplier<RequestContext> contextSupplier,
            MetricsRecorder metricsRecorder,
            Duration lockTimeout,
            ExternalCacheExceptionListener externalCacheExceptionListener) {
        super(name, lockTimeout, externalCacheExceptionListener);
        this.contextSupplier = requireNonNull(contextSupplier);
        this.metricsRecorder = requireNonNull(metricsRecorder);
    }

    /**
     * Performs a direct get operation against the external cache using the supplied external key.
     *
     * @param externalKey the key in its external form
     * @return the value held by the external cache for the key, if any
     */
    protected abstract Optional<V> directGet(String externalKey);

    /**
     * Performs a direct bulk get operation against the external cache using the supplied external keys.
     *
     * @param externalKeys the keys in their external form
     * @return a map from external key to the value held by the external cache for that key, if any
     */
    protected abstract Map<String, Optional<V>> directGetBulk(Set<String> externalKeys);

    /**
     * Returns the value for the key, preferring what has been recorded in the request context.
     * If nothing is recorded and a {@code removeAll()} has been recorded, the result is empty without
     * querying the external cache. Otherwise the external cache is queried and the outcome is recorded.
     *
     * @param internalKey the key in its internal form
     * @return a stage yielding the value, if any
     */
    @Override
    public final CompletionStage<Optional<V>> get(String internalKey) {
        return perform(() -> {
            final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();

            // Check if we have recorded a value already
            final Optional<Optional<V>> recordedValue = cacheContext.getValueRecorded(internalKey);

            return recordedValue.orElseGet(() -> {
                // Check if a removeAll() has happened
                if (cacheContext.hasRemoveAll()) {
                    return Optional.empty();
                }

                // Now check externally
                final String externalKey = cacheContext.externalEntryKeyFor(internalKey);
                metricsRecorder.record(name, CacheType.EXTERNAL, MetricLabel.NUMBER_OF_REMOTE_GET, 1);
                final Optional<V> externalValue = directGet(externalKey);
                cacheContext.recordValue(internalKey, externalValue);

                return externalValue;
            });
        });
    }

    /**
     * Returns the value for the key, creating it with the supplier when no value is known.
     * A value is created when the request context holds a recorded removal for the key, when a
     * {@code removeAll()} has been recorded, or when the external cache has no value for the key.
     *
     * @param internalKey the key in its internal form
     * @param supplier    creates the value when required; must not return {@code null}
     * @return a stage yielding the existing or newly supplied value
     */
    @Override
    public final CompletionStage<V> get(String internalKey, Supplier<V> supplier) {
        return perform(() -> {
            final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();

            final String externalKey = cacheContext.externalEntryKeyFor(internalKey);

            // Check if we have recorded a value already
            final Optional<Optional<V>> recordedValue = cacheContext.getValueRecorded(internalKey);
            if (recordedValue.isPresent()) {
                if (recordedValue.get().isPresent()) {
                    return recordedValue.get().get();
                }
                // There was a remove, so need to re-create
                getLogger().trace("Cache {}, creating candidate for key {}", name, internalKey);
                return handleCreation(internalKey, supplier);
            }

            // If a transactional cache has fired removeAll, then we always consider the remote to be empty.
            if (!cacheContext.hasRemoveAll()) {
                metricsRecorder.record(name, CacheType.EXTERNAL, MetricLabel.NUMBER_OF_REMOTE_GET, 1);
                final Optional<V> result = directGet(externalKey);
                if (result.isPresent()) {
                    // A valid value exists in the external cache
                    cacheContext.recordValue(internalKey, result);
                    return result.get();
                }
            }
            getLogger().trace("Cache {}, creating candidate for key {}", name, internalKey);
            return handleCreation(internalKey, supplier);
        });
    }

    /**
     * Returns the values for the supplied keys. Keys already resolved by the request context (a
     * recorded value, or empty because a {@code removeAll()} was recorded) are answered locally; the
     * remainder are fetched with a single bulk get and the outcomes are recorded.
     *
     * @param internalKeys the keys in their internal form
     * @return a stage yielding a map from internal key to the value, if any
     */
    @Override
    public final CompletionStage<Map<String, Optional<V>>> getBulk(Iterable<String> internalKeys) {
        return perform(() -> {
            if (isEmpty(internalKeys)) {
                return new HashMap<>();
            }

            // Get the recorded values first
            final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
            final Map<String, Optional<V>> grandResult = checkValuesRecorded(internalKeys);

            // Calculate the externalKeys for the entries that are missing
            final Set<String> missingExternalKeys = StreamSupport.stream(internalKeys.spliterator(), false)
                    .filter(k -> !grandResult.containsKey(k))
                    .map(cacheContext::externalEntryKeyFor)
                    .collect(Collectors.toSet());

            if (missingExternalKeys.isEmpty()) {
                getLogger().trace("Cache {}: getBulk(): have all the requested entries cached", name);
                return grandResult;
            }
            getLogger().trace("Cache {}: getBulk(): not cached {} requested entries", name, missingExternalKeys.size());

            // Get the missing values.
            metricsRecorder.record(name, CacheType.EXTERNAL, MetricLabel.NUMBER_OF_REMOTE_GET, 1);
            final Map<String, Optional<V>> candidateValues = directGetBulk(missingExternalKeys);

            // Record each fetched outcome and add it to the result, translating the key only once.
            candidateValues.forEach((externalKey, candidate) -> {
                final String internalKey = cacheContext.internalEntryKeyFor(externalKey);
                cacheContext.recordValue(internalKey, candidate);
                grandResult.put(internalKey, candidate);
            });

            return grandResult;
        });
    }

    /**
     * Returns the values for the supplied keys, using the factory to create the values that are
     * neither recorded in the request context nor present in the external cache.
     *
     * @param factory      creates values for the internal keys that have no known value
     * @param internalKeys the keys in their internal form
     * @return a stage yielding a map from internal key to value
     */
    @Override
    public final CompletionStage<Map<String, V>> getBulk(
            Function<Set<String>, Map<String, V>> factory, Iterable<String> internalKeys) {
        return perform(() -> {
            if (isEmpty(internalKeys)) {
                return new HashMap<>();
            }

            // Get the recorded values first and collect the ones that have values.
            // An explicit HashMap is used because this map is added to below, and the map produced by
            // Collectors.toMap() carries no mutability guarantee.
            final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
            final Map<String, V> grandResult = new HashMap<>();
            checkValuesRecorded(internalKeys).forEach(
                    (key, recorded) -> recorded.ifPresent(value -> grandResult.put(key, value)));

            // Calculate the candidate externalKeys for the entries that are missing
            final Set<String> candidateMissingExternalKeys = StreamSupport.stream(internalKeys.spliterator(), false)
                    .filter(k -> !grandResult.containsKey(k))
                    .map(cacheContext::externalEntryKeyFor)
                    .collect(Collectors.toSet());

            // Bail out if we have all the entries requested
            if (candidateMissingExternalKeys.isEmpty()) {
                getLogger().trace("Cache {}: getBulk(Function): had all the requested entries cached", name);
                return grandResult;
            }
            getLogger().trace("Cache {}: getBulk(Function): checking external cache for {} keys",
                    name, candidateMissingExternalKeys.size());

            final Map<String, V> missingValues = handleCreation(factory, candidateMissingExternalKeys);
            cacheContext.recordValues(missingValues);
            grandResult.putAll(missingValues);

            return grandResult;
        });
    }

    /**
     * Records a put of the value against the request context.
     *
     * @param internalKey the key in its internal form
     * @param value       the value to put
     * @param policy      the policy to apply to the put
     */
    @Override
    public final void put(String internalKey, V value, PutPolicy policy) {
        final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
        cacheContext.recordPut(internalKey, value, policy);
    }

    /**
     * Records the removal of the supplied keys against the request context.
     *
     * @param internalKeys the keys in their internal form
     */
    @Override
    public final void remove(Iterable<String> internalKeys) {
        final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
        cacheContext.recordRemove(internalKeys);
    }

    /**
     * Records the removal of all entries against the request context.
     */
    @Override
    public final void removeAll() {
        final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
        cacheContext.recordRemoveAll();
    }

    /**
     * Forgets everything held by this cache's request context, if one exists.
     *
     * @return {@code true} if the context existed and reported pending operations before being
     *         cleared, {@code false} otherwise
     */
    @Override
    public final boolean transactionDiscard() {
        final RequestContext requestContext = contextSupplier.get();
        final Optional<AbstractExternalCacheRequestContext<V>> cacheRequestContext = requestContext.get(this);

        if (!cacheRequestContext.isPresent()) {
            // there are no pending operations
            return false;
        }

        // Capture the answer before clearing the context.
        final AbstractExternalCacheRequestContext<V> cacheContext = cacheRequestContext.get();
        final boolean hasPendingOperations = cacheContext.hasPendingOperations();
        cacheContext.forgetAll();
        return hasPendingOperations;
    }

    /**
     * Resolves as many of the keys as possible from the request context alone.
     *
     * @param internalKeys the keys in their internal form
     * @return a new mutable map holding, for each key with a recorded value, that value, and for each
     *         unrecorded key an empty value when a {@code removeAll()} has been recorded; keys that
     *         cannot be resolved locally are absent
     */
    private Map<String, Optional<V>> checkValuesRecorded(Iterable<String> internalKeys) {
        final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();

        final Map<String, Optional<V>> result = new HashMap<>();

        internalKeys.forEach(k -> {
            final Optional<Optional<V>> valueRecorded = cacheContext.getValueRecorded(k);
            if (valueRecorded.isPresent()) {
                result.put(k, valueRecorded.get());
            } else if (cacheContext.hasRemoveAll()) {
                result.put(k, Optional.empty());
            }
        });

        return result;
    }

    /**
     * Obtains a value from the supplier and records it against the request context, both as a put
     * with the {@link PutPolicy#ADD_ONLY} policy and as the known value for the key.
     *
     * @param internalKey the key in its internal form
     * @param supplier    creates the value; must not return {@code null}
     * @return the supplied value
     * @throws NullPointerException if the supplier returns {@code null}
     */
    private V handleCreation(String internalKey, Supplier<V> supplier) throws ExecutionException, InterruptedException {
        final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
        final V suppliedValue = requireNonNull(supplier.get());
        cacheContext.recordPutPolicy(internalKey, suppliedValue, PutPolicy.ADD_ONLY);
        cacheContext.recordValue(internalKey, Optional.of(suppliedValue));
        return suppliedValue;
    }

    /**
     * Resolves values for the supplied external keys, first from the external cache (where allowed)
     * and then from the factory for whatever is still missing. Factory-created values are recorded
     * as {@link PutPolicy#ADD_ONLY} puts.
     *
     * @param factory      creates values for the internal keys that have no known value
     * @param externalKeys the keys, in their external form, that have no usable recorded value
     * @return a map from internal key to the value found or created
     */
    private Map<String, V> handleCreation(Function<Set<String>, Map<String, V>> factory, Set<String> externalKeys)
            throws ExecutionException, InterruptedException {
        // Get the missing values from the external cache.
        final AbstractExternalCacheRequestContext<V> cacheContext = ensureCacheContext();

        // Need to handle if removeAll has been performed OR a remove has been done, and hence not check remotely.
        // Otherwise, need to check remotely.
        final Map<String, V> grandResult = new HashMap<>();
        final Set<String> missingExternalKeys = fillInKnownValuesFromBackingCache(cacheContext, externalKeys, grandResult);

        if (!missingExternalKeys.isEmpty()) {
            getLogger().trace("Cache {}: getBulk(Function): calling factory to create {} values",
                    name, missingExternalKeys.size());
            // Okay, need to get the missing values and mapping from externalKeys to internalKeys
            final Set<String> missingInternalKeys = Collections.unmodifiableSet(
                    missingExternalKeys.stream().map(cacheContext::internalEntryKeyFor).collect(Collectors.toSet()));
            final Map<String, V> missingValues = factory.apply(missingInternalKeys);
            FactoryUtils.verifyFactoryResult(missingValues, missingInternalKeys);

            // Okay, got the missing values, now need to record adding them
            missingValues.forEach((internalKey, value) -> put(internalKey, value, PutPolicy.ADD_ONLY));

            grandResult.putAll(missingValues);
        }

        return grandResult;
    }

    /**
     * Adds to {@code grandResult} the values that exist in the external cache and works out which
     * keys still need a value created.
     * <p>
     * When a {@code removeAll()} has been recorded the external cache is not queried and every key is
     * reported as missing. Otherwise keys with a recorded removal are reported as missing without a
     * remote lookup, and the remaining keys are fetched with a single bulk get.
     *
     * @param cacheContext the request context for this cache
     * @param externalKeys the keys, in their external form, to resolve
     * @param grandResult  receives, keyed by internal key, the values found in the external cache
     * @return the external keys for which no value is known
     */
    private Set<String> fillInKnownValuesFromBackingCache(
            AbstractExternalCacheRequestContext<V> cacheContext, Set<String> externalKeys, Map<String, V> grandResult) {
        if (cacheContext.hasRemoveAll()) {
            return externalKeys;
        }

        // Initial list of missing keys are the keys that have a recorded remove operation against them.
        // The caller only passes keys without a present recorded value, so any recorded value here is a
        // removal. A HashSet is requested explicitly because the set is added to below, and the set
        // produced by Collectors.toSet() carries no mutability guarantee.
        final Set<String> missingExternalKeys = externalKeys.stream()
                .filter(k -> cacheContext.getValueRecorded(cacheContext.internalEntryKeyFor(k)).isPresent())
                .collect(Collectors.toCollection(HashSet::new));

        // Calculate list of keys we need to check for, as they may exist remotely
        final Set<String> externalKeysNotRemoved = externalKeys.stream()
                .filter(k -> !missingExternalKeys.contains(k))
                .collect(Collectors.toSet());

        if (!externalKeysNotRemoved.isEmpty()) {
            metricsRecorder.record(name, CacheType.EXTERNAL, MetricLabel.NUMBER_OF_REMOTE_GET, 1);
            final Map<String, Optional<V>> candidateValues = directGetBulk(externalKeysNotRemoved);

            candidateValues.forEach((externalKey, candidate) -> {
                if (candidate.isPresent()) {
                    grandResult.put(cacheContext.internalEntryKeyFor(externalKey), candidate.get());
                } else {
                    missingExternalKeys.add(externalKey);
                }
            });
        }

        return missingExternalKeys;
    }
}
