package com.atlassian.crowd.crypto;

import com.atlassian.crowd.embedded.api.Encryptor;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.time.Duration;
import java.util.function.BooleanSupplier;
import java.util.function.UnaryOperator;

/**
 * Wrapper that caches encryption and decryption results.
 */
public class CachedEncryptor implements Encryptor {
    private static final Logger logger = LoggerFactory.getLogger(CachedEncryptor.class);
    private static final String NULL_REPLACEMENT = "";
    private static final String NON_NULL_PREFIX = "p";

    private final LoadingCache<String, String> encryptionCache;
    private final LoadingCache<String, String> decryptionCache;
    private final BooleanSupplier cacheEnabledSupplier;

    private final Encryptor delegate;

    public CachedEncryptor(Encryptor delegate, long maxCacheSize, Duration expireAfterAccess, BooleanSupplier cacheEnabledSupplier) {
        Preconditions.checkArgument(!(delegate instanceof SaltingEncryptor), "SaltingEncryptor should not be cached");
        this.delegate = delegate;
        this.encryptionCache = createCache(maxCacheSize, expireAfterAccess, password -> {
            logger.debug("Encrypted password not found in encryption cache. Encrypting.");
            return delegate.encrypt(password);
        });
        this.decryptionCache = createCache(maxCacheSize, expireAfterAccess, encryptedPassword -> {
            logger.debug("Decrypted password not found in decryption cache. Decrypting.");
            return delegate.decrypt(encryptedPassword);
        });
        this.cacheEnabledSupplier = cacheEnabledSupplier;
    }

    private static LoadingCache<String, String> createCache(long maxCacheSize,
                                                            Duration expireAfterAccess,
                                                            UnaryOperator<String> loader) {
        return CacheBuilder.newBuilder()
                .maximumSize(maxCacheSize)
                .expireAfterAccess(expireAfterAccess)
                .build(new CacheLoader<String, String>() {
                    @Override
                    @Nonnull
                    public String load(@Nonnull String sanitizedKey) {
                        return wrapNull(loader.apply(unwrapNull(sanitizedKey)));
                    }
                });
    }

    @Override
    public String encrypt(String password) {
        if (!cacheEnabledSupplier.getAsBoolean()) {
            return delegate.encrypt(password);
        }
        final String encrypted = get(encryptionCache, password);
        decryptionCache.put(wrapNull(encrypted), wrapNull(password));
        return encrypted;
    }

    @Override
    public String decrypt(String encryptedPassword) {
        if (!cacheEnabledSupplier.getAsBoolean()) {
            return delegate.decrypt(encryptedPassword);
        }
        return get(decryptionCache, encryptedPassword);
    }

    private String get(LoadingCache<String, String> cache, String original) {
        try {
            final String resultSanitized = cache.getUnchecked(wrapNull(original));
            return unwrapNull(resultSanitized);
        } catch (UncheckedExecutionException e) {
            Throwables.throwIfUnchecked(e.getCause());
            throw e;
        }
    }

    private static String wrapNull(final String data) {
        return data == null ? NULL_REPLACEMENT : (NON_NULL_PREFIX + data);
    }

    private static String unwrapNull(final String data) {
        return NULL_REPLACEMENT.equals(data) ? null : data.substring(NON_NULL_PREFIX.length());
    }

    @Override
    public boolean changeEncryptionKey() {
        try {
            return delegate.changeEncryptionKey();
        } finally {
            logger.debug("Clearing the encryption cache.");
            clearCache();
        }
    }

    @VisibleForTesting
    void clearCache() {
        encryptionCache.invalidateAll();
        decryptionCache.invalidateAll();
    }
}
