package com.atlassian.vcache.internal.memcached;

import com.atlassian.marshalling.api.MarshallingPair;
import com.atlassian.vcache.CasIdentifier;
import com.atlassian.vcache.ExternalCacheException;
import com.atlassian.vcache.IdentifiedValue;
import com.atlassian.vcache.PutPolicy;
import com.atlassian.vcache.internal.core.DefaultIdentifiedValue;
import com.atlassian.vcache.internal.core.VCacheCoreUtils;
import net.spy.memcached.CASValue;
import net.spy.memcached.MemcachedClientIF;
import net.spy.memcached.OperationTimeoutException;
import net.spy.memcached.internal.CheckedOperationTimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static com.atlassian.vcache.ExternalCacheException.Reason.TIMEOUT;
import static com.atlassian.vcache.ExternalCacheException.Reason.UNCLASSIFIED_FAILURE;
import static com.atlassian.vcache.internal.core.VCacheCoreUtils.unmarshall;

/**
 * Common utility methods that are specific for the Memcached implementation.
 *
 * @since 1.0.0
 */
class MemcachedUtils {
    private static final Logger log = LoggerFactory.getLogger(MemcachedUtils.class);

    /**
     * The maximum number of seconds (30 days) that may be sent to Memcached as a relative time to live.
     * Longer periods must instead be expressed as an absolute Unix timestamp (seconds since 1970).
     */
    private static final int MAX_SECONDS_OFFSET = 60 * 60 * 24 * 30;

    /**
     * Extracts the Memcached CAS id from a {@link CasIdentifier}.
     *
     * @param casId the identifier, expected to be a {@link MemcachedCasIdentifier}
     * @return the CAS id held by the identifier
     * @throws ExternalCacheException with reason {@code UNCLASSIFIED_FAILURE} if {@code casId} is {@code null}
     *                                or not a {@link MemcachedCasIdentifier}
     */
    static long safeExtractId(CasIdentifier casId) {
        if (casId instanceof MemcachedCasIdentifier) {
            return ((MemcachedCasIdentifier) casId).getId();
        }

        // Guard against null so callers get the documented ExternalCacheException rather than an NPE.
        final String className = (casId == null) ? "null" : casId.getClass().getName();
        log.warn("Passed an unknown CasIdentifier instance of class {}.", className);
        throw new ExternalCacheException(UNCLASSIFIED_FAILURE);
    }

    /**
     * Waits for a "get with CAS" operation to complete and converts the result into an {@link IdentifiedValue}.
     *
     * @param op               the pending operation
     * @param valueMarshalling used to unmarshall the stored bytes into a value
     * @param <V>              the value type
     * @return the identified value, or {@link Optional#empty()} if the operation produced no result
     * @throws ExternalCacheException with reason {@code UNCLASSIFIED_FAILURE} if the operation failed or the
     *                                waiting thread was interrupted
     */
    static <V> Optional<IdentifiedValue<V>> identifiedValueFrom(Future<CASValue<Object>> op, MarshallingPair<V> valueMarshalling) {
        final CASValue<Object> casValue;
        try {
            casValue = op.get();
        } catch (InterruptedException ex) {
            // Restore the interrupt status so code further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new ExternalCacheException(UNCLASSIFIED_FAILURE, ex);
        } catch (ExecutionException ex) {
            throw new ExternalCacheException(UNCLASSIFIED_FAILURE, ex);
        }

        if (casValue == null) {
            return Optional.empty();
        }

        final CasIdentifier identifier = new MemcachedCasIdentifier(casValue.getCas());
        final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(
                identifier, VCacheCoreUtils.unmarshall((byte[]) casValue.getValue(), valueMarshalling).get());
        return Optional.of(iv);
    }

    /**
     * Starts the Memcached store operation that matches the supplied {@link PutPolicy}.
     * {@code ADD_ONLY} uses {@code add}, {@code PUT_ALWAYS} uses {@code set} and {@code REPLACE_ONLY} uses
     * {@code replace}.
     *
     * @param policy      the put policy to apply
     * @param client      the client that issues the operation
     * @param externalKey the key to store under
     * @param defaultTtl  the expiry value passed unchanged to the client
     * @param valueBytes  the marshalled value
     * @return the pending operation's future
     * @throws IllegalArgumentException if {@code policy} is not one of the supported policies (including {@code null})
     */
    static Future<Boolean> putOperationForPolicy(
            PutPolicy policy, MemcachedClientIF client, String externalKey, int defaultTtl, byte[] valueBytes) {
        if (policy == PutPolicy.ADD_ONLY) {
            return client.add(externalKey, defaultTtl, valueBytes);
        } else if (policy == PutPolicy.PUT_ALWAYS) {
            return client.set(externalKey, defaultTtl, valueBytes);
        } else if (policy == PutPolicy.REPLACE_ONLY) {
            return client.replace(externalKey, defaultTtl, valueBytes);
        }
        throw new IllegalArgumentException("Unknown put policy: " + policy);
    }

    /**
     * Converts an exception raised while talking to Memcached into an {@link ExternalCacheException}.
     * <ul>
     *     <li>An {@link ExternalCacheException} is returned as-is, so its reason is kept.</li>
     *     <li>Spymemcached timeout exceptions map to reason {@code TIMEOUT}.</li>
     *     <li>A {@link RuntimeException} whose cause is an {@link Exception} is classified by that cause.</li>
     *     <li>Anything else maps to reason {@code UNCLASSIFIED_FAILURE}.</li>
     * </ul>
     *
     * @param ex the exception to map
     * @return the corresponding {@link ExternalCacheException}
     */
    static ExternalCacheException mapException(Exception ex) {
        if (ex instanceof ExternalCacheException) {
            // Already classified: re-wrapping would discard the original reason.
            return (ExternalCacheException) ex;
        }
        if (ex instanceof OperationTimeoutException || ex instanceof CheckedOperationTimeoutException) {
            return new ExternalCacheException(TIMEOUT, ex);
        }
        if (ex instanceof RuntimeException && ex.getCause() instanceof Exception) {
            // Unwrap; an ExternalCacheException cause is returned unchanged by the first check above.
            return mapException((Exception) ex.getCause());
        }
        return new ExternalCacheException(UNCLASSIFIED_FAILURE, ex);
    }

    /**
     * Fetches the supplied keys from Memcached with a single bulk call.
     *
     * @param externalKeys     the keys to fetch
     * @param clientSupplier   supplies the client used for the call
     * @param valueMarshalling used to unmarshall the stored bytes
     * @param <V>              the value type
     * @return a map with one entry per requested key. Keys missing from the server response are passed to
     *         {@code unmarshall} as {@code null} (presumably yielding an empty Optional; verify against
     *         {@code VCacheCoreUtils}).
     */
    static <V> Map<String, Optional<V>> directGetBulk(
            Set<String> externalKeys, Supplier<MemcachedClientIF> clientSupplier, MarshallingPair<V> valueMarshalling) {
        // getBulk only returns keys that have values, so every requested key is looked up explicitly below.
        final Map<String, Object> haveValues = clientSupplier.get().getBulk(externalKeys);

        return externalKeys.stream().collect(Collectors.toMap(
                Function.identity(),
                k -> unmarshall((byte[]) haveValues.get(k), valueMarshalling)));
    }

    /**
     * Reads the current cache version counter, initialising it to 1 if it does not exist.
     */
    private static long obtainCacheVersion(Supplier<MemcachedClientIF> clientSupplier, String externalCacheVersionKey) {
        // Incrementing by 0 returns the current value without changing it.
        return clientSupplier.get().incr(externalCacheVersionKey, 0, 1);
    }

    /**
     * Increments the cache version counter by 1, initialising it to 1 if it does not exist.
     */
    private static long incrementCacheVersion(Supplier<MemcachedClientIF> clientSupplier, String externalCacheVersionKey) {
        return clientSupplier.get().incr(externalCacheVersionKey, 1, 1);
    }

    /**
     * Returns a function that increments the version counter stored under the given key.
     * A client is obtained from {@code clientSupplier} on every invocation.
     *
     * @param clientSupplier supplies the client used for each increment
     * @return a function from version key to the incremented version
     */
    static Function<String, Long> cacheVersionIncrementer(Supplier<MemcachedClientIF> clientSupplier) {
        return externalCacheVersionKey -> incrementCacheVersion(clientSupplier, externalCacheVersionKey);
    }

    /**
     * Returns a function that reads the version counter stored under the given key.
     * A client is obtained from {@code clientSupplier} on every invocation.
     *
     * @param clientSupplier supplies the client used for each read
     * @return a function from version key to the current version
     */
    static Function<String, Long> cacheVersionSupplier(Supplier<MemcachedClientIF> clientSupplier) {
        return externalCacheVersionKey -> obtainCacheVersion(clientSupplier, externalCacheVersionKey);
    }

    /**
     * Calculates the expiry time for a Memcached entry, taking into account the
     * logic required to handle when the seconds are longer than 30 days.
     * <p>
     * Periods shorter than {@link #MAX_SECONDS_OFFSET} are returned unchanged as a relative offset. Longer
     * periods are converted to an absolute Unix timestamp. If that timestamp would not fit in an {@code int},
     * it is clamped to {@link Integer#MAX_VALUE}.
     *
     * @param seconds the period to expiry, in seconds
     * @return a Memcached compliant expiry time
     */
    static int expiryTime(int seconds) {
        if (seconds < MAX_SECONDS_OFFSET) {
            return seconds;
        }

        // Use long arithmetic: an int sum overflows for large TTLs, and the resulting negative value would be
        // treated by memcached as "already expired".
        final long absoluteExpiry = System.currentTimeMillis() / 1_000L + seconds;
        return (int) Math.min(absoluteExpiry, Integer.MAX_VALUE);
    }
}
