package com.atlassian.vcache.internal.core;

import com.atlassian.vcache.internal.BegunTransactionalActivityHandler;
import com.atlassian.vcache.internal.RequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.Objects.requireNonNull;

/**
 * Default {@link TransactionControlManager} that tracks, per {@link RequestContext}, the transactional
 * external caches touched during the request.
 * <p>
 * State is held in two entries of the request context:
 * <ul>
 *     <li>a map from cache name to the instrumented {@link TransactionControl} registered for it, and</li>
 *     <li>a flag recording whether the {@link BegunTransactionalActivityHandler} has already been notified
 *     since the last {@link #syncAll(RequestContext)} / {@link #discardAll(RequestContext)}.</li>
 * </ul>
 * NOTE(review): the registration map is a plain {@link HashMap}, so this presumably assumes a request
 * context is not shared across threads concurrently — confirm.
 *
 * @since 1.0.0
 */
public class DefaultTransactionControlManager implements TransactionControlManager {
    private static final Logger log = LoggerFactory.getLogger(DefaultTransactionControlManager.class);

    // Identity keys into the request context; a fresh Object per manager instance avoids collisions.
    private final Object transactionControllersKey = new Object();
    private final Object callbackKey = new Object();

    private final Instrumentor instrumentor;
    private final BegunTransactionalActivityHandler begunTransactionalActivityHandler;

    /**
     * @param instrumentor                      used to wrap each registered {@link TransactionControl}; must not be null
     * @param begunTransactionalActivityHandler notified when transactional activity begins in a request; must not be null
     */
    public DefaultTransactionControlManager(Instrumentor instrumentor,
                                            BegunTransactionalActivityHandler begunTransactionalActivityHandler) {
        this.instrumentor = requireNonNull(instrumentor);
        this.begunTransactionalActivityHandler = requireNonNull(begunTransactionalActivityHandler);
    }

    /**
     * Records {@code control} for {@code cacheName} in the request, unless a control is already registered
     * under that name (the first registration wins). Then notifies the begun-activity handler if it has not
     * been notified since the last sync/discard.
     */
    @Override
    public void registerTransactionalExternalCache(RequestContext requestContext,
                                                   String cacheName,
                                                   TransactionControl control) {
        final Map<String, TransactionControl> registered = controlsFor(requestContext);
        registered.computeIfAbsent(cacheName, ignoredKey -> wrapControl(control, cacheName));

        invokeCallbackIfNecessary(requestContext);
    }

    /**
     * Calls {@link TransactionControl#transactionSync()} on every control registered in the request.
     * Registrations are kept. Afterwards the begun-activity handler is re-armed, so the next registration
     * notifies it again.
     */
    @Override
    public void syncAll(RequestContext requestContext) {
        log.trace("Synchronising all caches");

        requestContext.<Map<String, TransactionControl>>get(transactionControllersKey)
                .ifPresent(this::syncEach);

        resetShouldInvokeCallback(requestContext);
    }

    /**
     * Calls {@link TransactionControl#transactionDiscard()} on every control registered in the request.
     * Afterwards the begun-activity handler is re-armed.
     *
     * @return the names of the caches whose {@code transactionDiscard()} returned {@code true}
     */
    @Override
    public Set<String> discardAll(RequestContext requestContext) {
        log.trace("Discarding all caches");

        final Set<String> discarded = new HashSet<>();
        requestContext.<Map<String, TransactionControl>>get(transactionControllersKey)
                .ifPresent(controls -> discardEach(controls, discarded));

        resetShouldInvokeCallback(requestContext);
        return discarded;
    }

    /** Returns this request's registration map, creating an empty one on first use. */
    private Map<String, TransactionControl> controlsFor(RequestContext requestContext) {
        return requestContext.computeIfAbsent(transactionControllersKey, HashMap::new);
    }

    /** Wraps a newly registered control with the instrumentor; only invoked for a name's first registration. */
    private TransactionControl wrapControl(TransactionControl control, String cacheName) {
        log.trace("Registering {}", cacheName);
        return instrumentor.wrap(control, cacheName);
    }

    /** Syncs each registered control, in the map's iteration order. */
    private void syncEach(Map<String, TransactionControl> controls) {
        controls.forEach((name, txControl) -> {
            log.trace("Syncing {}", name);
            txControl.transactionSync();
        });
    }

    /** Discards each registered control, adding to {@code discarded} the names whose discard returned true. */
    private void discardEach(Map<String, TransactionControl> controls, Set<String> discarded) {
        controls.forEach((name, txControl) -> {
            log.trace("Discarding {}", name);
            final boolean wasDiscarded = txControl.transactionDiscard();
            if (wasDiscarded) {
                discarded.add(name);
            }
        });
    }

    /** Notifies the handler only if the flag flips from false to true, i.e. once per reset window. */
    private void invokeCallbackIfNecessary(RequestContext requestContext) {
        final AtomicBoolean alreadyInvoked = getCallbackInvokedFlag(requestContext);
        final boolean firstSinceReset = alreadyInvoked.compareAndSet(false, true);
        if (firstSinceReset) {
            begunTransactionalActivityHandler.onRequest(requestContext);
        }
    }

    /** Clears the flag so that the next registration notifies the handler again. */
    private void resetShouldInvokeCallback(RequestContext requestContext) {
        final AtomicBoolean alreadyInvoked = getCallbackInvokedFlag(requestContext);
        alreadyInvoked.set(false);
    }

    /** Returns this request's "handler notified" flag, creating it (initially false) on first use. */
    private AtomicBoolean getCallbackInvokedFlag(RequestContext requestContext) {
        return requestContext.computeIfAbsent(callbackKey, AtomicBoolean::new);
    }
}
