package com.atlassian.crowd.directory.authentication;

import com.google.common.cache.LoadingCache;
import com.sun.jersey.api.client.ClientHandler;
import com.sun.jersey.api.client.ClientRequest;
import com.sun.jersey.api.client.ClientResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.ws.rs.core.HttpHeaders;
import java.util.function.Supplier;

/**
 * Adds Azure AD authentication token support to Jersey. Uses a node-local cache containing the most recent token value,
 * which is guaranteed to be refreshed by a single node. Each node manages its own sessions and tokens as it needs.
 * The lifecycle of this cache is tied to the lifecycle of this class, which is in turn tied to the lifecycle of the
 * {@link com.atlassian.crowd.directory.rest.AzureAdRestClient} (generally the same as the lifecycle of
 * {@link com.atlassian.crowd.directory.AzureAdDirectory}).
 */
public class AzureAdTokenRefresher {

    /** The single key under which the current Azure AD token is stored in {@link #tokenCache}. */
    public static final String AZURE_AD_TOKEN_CACHE_KEY = "AZURE_AD_TOKEN";
    private static final Logger log = LoggerFactory.getLogger(AzureAdTokenRefresher.class);
    private final LoadingCache<String, String> tokenCache;

    /**
     * @param tokenCache cache whose loader produces the Authorization header value for
     *                   {@link #AZURE_AD_TOKEN_CACHE_KEY}; it is read on every request and invalidated on a 401
     */
    public AzureAdTokenRefresher(final LoadingCache<String, String> tokenCache) {
        this.tokenCache = tokenCache;
    }

    /**
     * Sets the current authentication token in the request, creating one if necessary, and adds it as a request header.
     * If the request is rejected with 401 Unauthorized, the cached token is invalidated, a new one is obtained and the
     * request is retried exactly once. The rejected response is always released before retrying.
     *
     * @param request the outgoing request to authenticate; its Authorization header is overwritten
     * @param next a supplier of the next handler in the chain
     * @return the ClientResponse returned from the next handler (from the retry, if the first attempt got a 401)
     */
    public ClientResponse handle(final ClientRequest request, final Supplier<ClientHandler> next) {
        setTokenInRequest(request);
        final ClientResponse response = next.get().handle(request);
        if (response.getClientResponseStatus() != ClientResponse.Status.UNAUTHORIZED) {
            return response;
        }
        // The 401 response is discarded, so its entity (and the connection behind it) must be released here;
        // otherwise it leaks whenever debug logging is disabled and nobody reads the body.
        discardUnauthorizedResponse(response);
        tokenCache.invalidate(AZURE_AD_TOKEN_CACHE_KEY);
        setTokenInRequest(request);
        return next.get().handle(request);
    }

    /**
     * Best-effort cleanup of a rejected response: logs its body at debug level and closes it. Failures are only
     * logged, because diagnostics and cleanup of a response we are throwing away must not prevent the retry.
     */
    private static void discardUnauthorizedResponse(final ClientResponse response) {
        try {
            if (log.isDebugEnabled()) {
                log.debug("Got a 401 response from Microsoft Graph, retrying the request. Response body: {}",
                        response.getEntity(String.class));
            }
        } catch (RuntimeException e) {
            log.debug("Could not read the body of the 401 response from Microsoft Graph", e);
        } finally {
            try {
                response.close();
            } catch (RuntimeException e) {
                log.debug("Could not close the 401 response from Microsoft Graph", e);
            }
        }
    }

    /**
     * Reads the token from the cache (loading it through the cache's loader if absent) and sets it as the single
     * Authorization header value of the request, replacing any previous value.
     * NOTE(review): the value is used verbatim, so the loader presumably includes the auth scheme (e.g. "Bearer ")
     * — confirm against the cache loader.
     */
    private void setTokenInRequest(final ClientRequest cr) {
        final String newAzureAdToken = tokenCache.getUnchecked(AZURE_AD_TOKEN_CACHE_KEY);
        cr.getHeaders().putSingle(HttpHeaders.AUTHORIZATION, newAzureAdToken);
    }
}
