package com.atlassian.vcache.internal.memcached;

import com.atlassian.marshalling.api.MarshallingPair;
import com.atlassian.vcache.CasIdentifier;
import com.atlassian.vcache.DirectExternalCache;
import com.atlassian.vcache.ExternalCacheException;
import com.atlassian.vcache.ExternalCacheSettings;
import com.atlassian.vcache.IdentifiedValue;
import com.atlassian.vcache.PutPolicy;
import com.atlassian.vcache.internal.RequestContext;
import com.atlassian.vcache.internal.core.DefaultIdentifiedValue;
import com.atlassian.vcache.internal.core.ExternalCacheKeyGenerator;
import com.atlassian.vcache.internal.core.VCacheCoreUtils;
import com.atlassian.vcache.internal.core.service.AbstractExternalCache;
import com.atlassian.vcache.internal.core.service.FactoryUtils;
import com.atlassian.vcache.internal.core.service.VersionedExternalCacheRequestContext;
import com.google.common.annotations.VisibleForTesting;
import net.spy.memcached.CASResponse;
import net.spy.memcached.CASValue;
import net.spy.memcached.MemcachedClientIF;
import net.spy.memcached.OperationTimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static com.atlassian.vcache.internal.core.VCacheCoreUtils.isEmpty;
import static com.atlassian.vcache.internal.core.VCacheCoreUtils.marshall;
import static com.atlassian.vcache.internal.core.VCacheCoreUtils.unmarshall;
import static com.atlassian.vcache.internal.memcached.MemcachedUtils.expiryTime;
import static com.atlassian.vcache.internal.memcached.MemcachedUtils.identifiedValueFrom;
import static com.atlassian.vcache.internal.memcached.MemcachedUtils.putOperationForPolicy;
import static com.atlassian.vcache.internal.memcached.MemcachedUtils.safeExtractId;
import static java.util.Objects.requireNonNull;

/**
 * Implementation of the {@link DirectExternalCache} that uses Memcached.
 * <p>
 * Every public operation is wrapped in {@code perform(...)} (not defined in this file), talks to Memcached
 * through the {@link MemcachedClientIF} obtained from the configured client supplier, and translates the caller's
 * internal keys to external keys using the per-request {@link VersionedExternalCacheRequestContext}. Entries
 * written by this class are given an expiry computed by {@code expiryTime(ttlSeconds)}.
 *
 * @param <V> the value type
 * @since 1.0
 */
class MemcachedDirectExternalCache<V>
        extends AbstractExternalCache<V>
        implements DirectExternalCache<V> {
    private static final Logger log = LoggerFactory.getLogger(MemcachedDirectExternalCache.class);

    private final Supplier<MemcachedClientIF> clientSupplier;
    private final Supplier<RequestContext> contextSupplier;
    private final ExternalCacheKeyGenerator keyGenerator;
    private final MarshallingPair<V> valueMarshalling;
    private final int ttlSeconds;

    /**
     * Creates the cache.
     *
     * @param serviceSettings  supplies the lock timeout, the exception listener and the Memcached client supplier
     * @param contextSupplier  supplies the {@link RequestContext} that holds this cache's per-request context
     * @param keyGenerator     used when building the per-request cache context
     * @param name             the name of the cache
     * @param valueMarshalling used to convert values to and from the bytes stored in Memcached
     * @param settings         the cache settings; the default TTL is read unconditionally, so it must be present
     * @throws NullPointerException if the client supplier, {@code contextSupplier}, {@code keyGenerator} or
     *                              {@code valueMarshalling} is {@code null}
     */
    MemcachedDirectExternalCache(
            MemcachedVCacheServiceSettings serviceSettings,
            Supplier<RequestContext> contextSupplier,
            ExternalCacheKeyGenerator keyGenerator,
            String name,
            MarshallingPair<V> valueMarshalling,
            ExternalCacheSettings settings) {
        super(name, serviceSettings.getLockTimeout(), serviceSettings.getExternalCacheExceptionListener());
        this.clientSupplier = requireNonNull(serviceSettings.getClientSupplier());
        this.contextSupplier = requireNonNull(contextSupplier);
        this.keyGenerator = requireNonNull(keyGenerator);
        this.valueMarshalling = requireNonNull(valueMarshalling);
        this.ttlSeconds = VCacheCoreUtils.roundUpToSeconds(settings.getDefaultTtl().get());
    }

    /**
     * Reads the value stored under a key.
     *
     * @param internalKey the caller-facing key
     * @return a stage holding the result of unmarshalling whatever the client returned for the external key
     */
    @Override
    public CompletionStage<Optional<V>> get(String internalKey) {
        return perform(() -> {
            final String externalKey = buildExternalKey(internalKey);
            return unmarshall((byte[]) clientSupplier.get().get(externalKey), valueMarshalling);
        });
    }

    /**
     * Reads the value stored under a key, creating it from {@code supplier} when nothing is stored.
     * <p>
     * The supplier is only invoked when the initial read finds no value. The candidate is then written with
     * {@code add}; if the add is rejected, the value currently stored is read back and returned instead. The
     * add/read cycle repeats until one of the two succeeds.
     *
     * @param internalKey the caller-facing key
     * @param supplier    creates the candidate value; must not return {@code null}
     * @return a stage holding either the existing value, the value stored by someone else, or the candidate
     */
    @Override
    public CompletionStage<V> get(String internalKey, Supplier<V> supplier) {
        return perform(() -> {
            final String externalKey = buildExternalKey(internalKey);
            final Optional<V> existingValue = unmarshall((byte[]) clientSupplier.get().get(externalKey), valueMarshalling);
            if (existingValue.isPresent()) {
                return existingValue.get();
            }

            log.trace("Cache {}, creating candidate for key {}", name, internalKey);
            final V candidateValue = requireNonNull(supplier.get());
            final byte[] candidateValueBytes = valueMarshalling.getMarshaller().marshallToBytes(candidateValue);

            // Loop until either able to add the candidate value, or retrieve one that has been added by another thread
            for (; ; ) {
                final Future<Boolean> addOp = clientSupplier.get().add(externalKey, expiryTime(ttlSeconds), candidateValueBytes);
                if (addOp.get()) {
                    // Break rather than return from here: the original author reports that the compiler would
                    // not accept this lambda unless it was structured this way.
                    break;
                }

                log.info("Cache {}, unable to add candidate for key {}, retrieve what was added", name, internalKey);
                final Optional<V> otherAddedValue = unmarshall((byte[]) clientSupplier.get().get(externalKey), valueMarshalling);
                if (otherAddedValue.isPresent()) {
                    return otherAddedValue.get();
                }

                // The add was rejected yet nothing could be read back, so go around again.
                log.info("Cache {}, unable to retrieve recently added candidate for key {}, looping", name, internalKey);
            }
            return candidateValue;
        });
    }

    /**
     * Reads the value stored under a key together with its CAS identifier.
     *
     * @param internalKey the caller-facing key
     * @return a stage holding the identified value, or an empty {@link Optional} when {@code gets} returns
     * {@code null}
     */
    @Override
    public CompletionStage<Optional<IdentifiedValue<V>>> getIdentified(String internalKey) {
        return perform(() -> {
            final String externalKey = buildExternalKey(internalKey);
            final CASValue<Object> casValue = clientSupplier.get().gets(externalKey);
            if (casValue == null) {
                return Optional.empty();
            }

            final CasIdentifier identifier = new MemcachedCasIdentifier(casValue.getCas());
            final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(
                    identifier, valueMarshalling.getUnmarshaller().unmarshallFrom((byte[]) casValue.getValue()));
            return Optional.of(iv);
        });
    }

    /**
     * Reads the value stored under a key together with its CAS identifier, creating the value from
     * {@code supplier} when nothing is stored.
     * <p>
     * The supplier is only invoked when the initial {@code gets} finds nothing. After attempting to add the
     * candidate, the entry is always read back with {@code gets} because that is how the CAS identifier is
     * obtained; the value returned is therefore whatever is stored at that point. The add/read cycle repeats
     * until a read succeeds.
     *
     * @param internalKey the caller-facing key
     * @param supplier    creates the candidate value; must not return {@code null}
     * @return a stage holding the identified value read from Memcached
     */
    @Override
    public CompletionStage<IdentifiedValue<V>> getIdentified(String internalKey, Supplier<V> supplier) {
        return perform(() -> {
            final String externalKey = buildExternalKey(internalKey);
            final CASValue<Object> existingCasValue = clientSupplier.get().gets(externalKey);

            if (existingCasValue != null) {
                final CasIdentifier identifier = new MemcachedCasIdentifier(existingCasValue.getCas());
                final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(
                        identifier, valueMarshalling.getUnmarshaller().unmarshallFrom((byte[]) existingCasValue.getValue()));
                return iv;
            }

            log.trace("Cache {}, creating candidate for key {}", name, internalKey);
            final V candidateValue = requireNonNull(supplier.get());
            final byte[] candidateValueBytes = valueMarshalling.getMarshaller().marshallToBytes(candidateValue);

            // Loop until either able to add the candidate value, or retrieve one that has been added by another thread.
            for (; ; ) {
                final Future<Boolean> addOp = clientSupplier.get().add(externalKey, expiryTime(ttlSeconds), candidateValueBytes);
                if (!addOp.get()) {
                    log.trace("Cache {}, unable to add candidate for key {}", name, internalKey);
                }

                // Regardless of whether able to add an entry or not, need to retrieve to get the CAS value.
                log.trace("Cache {}, retrieving the candidate for key {}", name, internalKey);
                final CASValue<Object> otherAddedCasValue = clientSupplier.get().gets(externalKey);
                if (otherAddedCasValue != null) {
                    final CasIdentifier identifier = new MemcachedCasIdentifier(otherAddedCasValue.getCas());
                    final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(
                            identifier, valueMarshalling.getUnmarshaller().unmarshallFrom((byte[]) otherAddedCasValue.getValue()));
                    return iv;
                }

                log.info("Cache {}, unable to retrieve recently added candidate for key {}, looping", name, internalKey);
            }
        });
    }

    /**
     * Reads the values stored under several keys with a single bulk request.
     *
     * @param internalKeys the caller-facing keys; duplicates are collapsed
     * @return a stage holding a map with one entry per distinct key, whose value is the result of unmarshalling
     * what the bulk request returned for that key; an empty map when no keys are supplied
     */
    @Override
    public CompletionStage<Map<String, Optional<V>>> getBulk(Iterable<String> internalKeys) {
        return perform(() -> {
            if (isEmpty(internalKeys)) {
                return new HashMap<>();
            }

            // De-duplicate the keys and calculate the externalKeys
            final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();

            final Set<String> externalKeys = StreamSupport.stream(internalKeys.spliterator(), false)
                    .map(cacheContext::externalEntryKeyFor)
                    .collect(Collectors.toSet());

            // Returns map of keys that contain values, so need to handle the missing ones
            final Map<String, Object> haveValues = clientSupplier.get().getBulk(externalKeys);

            // Walk the requested keys (not the returned ones) so that absent entries are represented as well.
            return externalKeys.stream().collect(Collectors.toMap(
                    cacheContext::internalEntryKeyFor,
                    k -> unmarshall((byte[]) haveValues.get(k), valueMarshalling)));
        });
    }

    /**
     * Reads the values stored under several keys, using {@code factory} to create the ones that are missing.
     * <p>
     * The factory is invoked at most once, with the set of internal keys that the bulk request returned nothing
     * for. Its result is checked with {@code FactoryUtils.verifyFactoryResult}, each created value is written
     * with {@code set}, and the created values are merged into the returned map.
     *
     * @param factory      creates values for the missing keys
     * @param internalKeys the caller-facing keys; duplicates are collapsed
     * @return a stage holding a map of internal key to value; an empty map when no keys are supplied
     */
    @Override
    public CompletionStage<Map<String, V>> getBulk(Function<Set<String>, Map<String, V>> factory, Iterable<String> internalKeys) {
        return perform(() -> {
            if (isEmpty(internalKeys)) {
                return new HashMap<>();
            }

            // De-duplicate the keys and calculate the externalKeys
            final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();

            final Set<String> externalKeys = Collections.unmodifiableSet(
                    StreamSupport.stream(internalKeys.spliterator(), false)
                            .map(cacheContext::externalEntryKeyFor)
                            .collect(Collectors.toSet()));

            // Returns map of keys that contain values, so need to calculate the
            // missing ones
            final Map<String, Object> haveValues = clientSupplier.get().getBulk(externalKeys);
            log.trace("{} of {} entries have values", haveValues.size(), externalKeys.size());
            final Set<String> missingExternalKeys = new HashSet<>(externalKeys);
            missingExternalKeys.removeAll(haveValues.keySet());

            // Add the existing values to the grand result
            final Map<String, V> grandResult = haveValues.entrySet().stream().collect(Collectors.toMap(
                    e -> cacheContext.internalEntryKeyFor(e.getKey()),
                    e -> unmarshall((byte[]) e.getValue(), valueMarshalling).get()
            ));

            if (!missingExternalKeys.isEmpty()) {
                // Okay, need to get the missing values and mapping from externalKeys to internalKeys
                final Set<String> missingInternalKeys = Collections.unmodifiableSet(
                        missingExternalKeys.stream().map(cacheContext::internalEntryKeyFor).collect(Collectors.toSet()));
                final Map<String, V> missingValues = factory.apply(missingInternalKeys);
                FactoryUtils.verifyFactoryResult(missingValues, missingInternalKeys);

                // Okay, got the missing values, now need to add them to Memcached.
                // All writes are lodged first and only awaited afterwards.
                final Map<String, Future<Boolean>> internalKeyToFutureMap = missingValues.entrySet().stream().collect(Collectors.toMap(
                        Map.Entry::getKey,
                        e -> clientSupplier.get().set(
                                cacheContext.externalEntryKeyFor(e.getKey()), expiryTime(ttlSeconds), marshall(e.getValue(), valueMarshalling))
                ));

                // Now wait for the outcomes and then add to the grand result
                for (Map.Entry<String, Future<Boolean>> e : internalKeyToFutureMap.entrySet()) {
                    e.getValue().get(); // Don't care about the result as it will always be true
                }

                grandResult.putAll(missingValues);
            }

            return grandResult;
        });
    }

    /**
     * Reads the values stored under several keys together with their CAS identifiers.
     * <p>
     * One asynchronous {@code gets} is issued per distinct key; each outcome is converted with
     * {@code MemcachedUtils.identifiedValueFrom}.
     *
     * @param internalKeys the caller-facing keys; duplicates are collapsed
     * @return a stage holding a map with one entry per distinct internal key; an empty map when no keys are
     * supplied
     */
    @Override
    public CompletionStage<Map<String, Optional<IdentifiedValue<V>>>> getBulkIdentified(Iterable<String> internalKeys) {
        return perform(() -> {
            if (isEmpty(internalKeys)) {
                return new HashMap<>();
            }

            // There is no equivalent call in the Spy Memcached client, so the calls need to be made async.
            final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();

            // De-duplicate the keys, create map on internalKey to the future
            final Map<String, Future<CASValue<Object>>> internalKeyToFuture =
                    StreamSupport.stream(internalKeys.spliterator(), false)
                            .distinct()
                            .collect(Collectors.toMap(
                                    Function.identity(),
                                    k -> clientSupplier.get().asyncGets(cacheContext.externalEntryKeyFor(k))
                            ));

            return internalKeyToFuture.entrySet().stream().collect(Collectors.toMap(
                    Map.Entry::getKey,
                    e -> identifiedValueFrom(e.getValue(), valueMarshalling)
            ));
        });
    }

    /**
     * Writes a value under a key, using the client operation chosen by
     * {@code MemcachedUtils.putOperationForPolicy} for the supplied policy.
     *
     * @param internalKey the caller-facing key
     * @param value       the value to store; must not be {@code null}
     * @param policy      the policy that selects the write operation
     * @return a stage holding the outcome reported by the write operation
     */
    @Override
    public CompletionStage<Boolean> put(String internalKey, V value, PutPolicy policy) {
        return perform(() -> {
            final String externalKey = buildExternalKey(internalKey);
            final byte[] valueBytes = valueMarshalling.getMarshaller().marshallToBytes(requireNonNull(value));

            final Future<Boolean> putOp =
                    putOperationForPolicy(policy, clientSupplier.get(), externalKey, expiryTime(ttlSeconds), valueBytes);

            return putOp.get();
        });
    }

    /**
     * Deletes the entry under a key, passing the id extracted from {@code casId} to the client's delete call.
     *
     * @param internalKey the caller-facing key
     * @param casId       the identifier obtained from an earlier identified read
     * @return a stage holding the outcome reported by the delete operation
     */
    @Override
    public CompletionStage<Boolean> removeIf(String internalKey, CasIdentifier casId) {
        return perform(() -> {
            final String externalKey = buildExternalKey(internalKey);
            final Future<Boolean> delOp = clientSupplier.get().delete(externalKey, safeExtractId(casId));
            return delOp.get();
        });
    }

    /**
     * Replaces the entry under a key using the client's {@code cas} call with the id extracted from
     * {@code casId}.
     *
     * @param internalKey the caller-facing key
     * @param casId       the identifier obtained from an earlier identified read
     * @param newValue    the replacement value; must not be {@code null}
     * @return a stage holding {@code true} only when the client reports {@link CASResponse#OK}
     */
    @Override
    public CompletionStage<Boolean> replaceIf(String internalKey, CasIdentifier casId, V newValue) {
        return perform(() -> {
            final String externalKey = buildExternalKey(internalKey);
            final CASResponse casOp = clientSupplier.get().cas(
                    externalKey,
                    safeExtractId(casId),
                    expiryTime(ttlSeconds),
                    valueMarshalling.getMarshaller().marshallToBytes(requireNonNull(newValue)));
            return casOp == CASResponse.OK;
        });
    }

    /**
     * Deletes the entries under the supplied keys.
     * <p>
     * A delete is lodged for every key before any of them is awaited. The individual outcomes are ignored.
     *
     * @param internalKeys the caller-facing keys
     * @return a stage that completes once every delete has been awaited
     */
    @Override
    public CompletionStage<Void> remove(Iterable<String> internalKeys) {
        // There is no bulk delete in the api, so need to remove each one async
        return perform(() -> {
            if (isEmpty(internalKeys)) {
                return null;
            }

            // Lodge all the requests for delete
            final List<Future<Boolean>> deleteOps =
                    StreamSupport.stream(internalKeys.spliterator(), false)
                            .map(this::buildExternalKey)
                            .map(k -> clientSupplier.get().delete(k))
                            .collect(Collectors.toList());

            // Now wait for the outcome
            for (Future<Boolean> delOp : deleteOps) {
                delOp.get(); // don't care if succeeded or not
            }

            return null;
        });
    }

    /**
     * Removes all entries by asking the cache context to update its cache version with
     * {@code MemcachedUtils.cacheVersionIncrementer}; no individual entry is deleted by this method.
     * NOTE(review): presumably the new version makes previously written entries unreachable - confirm against
     * {@link VersionedExternalCacheRequestContext}.
     *
     * @return a stage that completes once the cache version has been updated
     */
    @Override
    public CompletionStage<Void> removeAll() {
        return perform(() -> {
            ensureCacheContext().updateCacheVersion(MemcachedUtils.cacheVersionIncrementer(clientSupplier));
            return null;
        });
    }

    /**
     * Updates the cache context's cache version using {@code MemcachedUtils.cacheVersionSupplier}.
     */
    @VisibleForTesting
    void refreshCacheVersion() {
        // Refresh the cacheVersion. Useful if want to get the current state of the external cache in testing.
        ensureCacheContext().updateCacheVersion(MemcachedUtils.cacheVersionSupplier(clientSupplier));
    }

    /**
     * Translates a caller-facing key into the key used in Memcached, via the current cache context.
     *
     * @param internalKey the caller-facing key
     * @return the external key
     * @throws OperationTimeoutException declared by the original author; presumably raised when the cache
     *                                   context has to be created and the client times out - TODO confirm
     */
    private String buildExternalKey(String internalKey) throws OperationTimeoutException {
        return ensureCacheContext().externalEntryKeyFor(internalKey);
    }

    /**
     * Returns this cache's context for the current {@link RequestContext}, creating and registering it when the
     * request context does not hold one yet.
     *
     * @return the cache context for the current request
     */
    protected VersionedExternalCacheRequestContext<V> ensureCacheContext() {
        final RequestContext requestContext = contextSupplier.get();

        return requestContext.computeIfAbsent(this, () -> {
            // Need to build a new context, which involves getting the current cache version, or setting it if it does
            // not exist.
            log.trace("Cache {}: Setting up a new context", name);
            return new VersionedExternalCacheRequestContext<>(
                    keyGenerator, name, requestContext::partitionIdentifier,
                    MemcachedUtils.cacheVersionSupplier(clientSupplier),
                    lockTimeout);
        });
    }

    /**
     * @return the logger of this class
     */
    @Override
    protected Logger getLogger() {
        return log;
    }

    /**
     * Converts an exception by delegating to {@code MemcachedUtils.mapException}.
     *
     * @param ex the exception to convert
     * @return the converted exception
     */
    @Override
    protected ExternalCacheException mapException(Exception ex) {
        return MemcachedUtils.mapException(ex);
    }
}
