package com.atlassian.crowd.directory.rest;

import com.atlassian.crowd.directory.authentication.AzureAdRefreshTokenFilter;
import com.atlassian.crowd.directory.authentication.AzureAdTokenRefresher;
import com.atlassian.crowd.directory.authentication.MsGraphApiAuthenticator;
import com.atlassian.crowd.directory.authentication.impl.MsalAuthenticatorFactory;
import com.atlassian.crowd.directory.rest.endpoint.AzureApiUriResolver;
import com.atlassian.crowd.directory.rest.util.IoUtilsWrapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.primitives.Ints;
import com.sun.jersey.api.client.Client;
import com.sun.jersey.api.client.config.DefaultClientConfig;
import org.codehaus.jackson.jaxrs.JacksonJaxbJsonProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link AzureAdRestClientFactory}.
 * <p>
 * Builds Jersey-backed {@link AzureAdRestClient}s. Each client:
 * <ul>
 *     <li>deserialises JSON with Jackson;</li>
 *     <li>has its connection and read timeouts narrowed to {@code int};</li>
 *     <li>passes every request through an {@link AzureAdRefreshTokenFilter}. That filter is fed access tokens by a
 *     {@link MsGraphApiAuthenticator} obtained from the injected {@link MsalAuthenticatorFactory}.</li>
 * </ul>
 */
public class DefaultAzureAdRestClientFactory implements AzureAdRestClientFactory {

    private static final Logger log = LoggerFactory.getLogger(DefaultAzureAdRestClientFactory.class);

    private final MsalAuthenticatorFactory msalAuthenticatorFactory;
    private final IoUtilsWrapper ioUtilsWrapper;

    /**
     * @param msalAuthenticatorFactory creates the authenticator whose API token is attached to outgoing requests
     * @param ioUtilsWrapper           handed unchanged to every {@link AzureAdRestClient} this factory creates
     */
    public DefaultAzureAdRestClientFactory(final MsalAuthenticatorFactory msalAuthenticatorFactory, IoUtilsWrapper ioUtilsWrapper) {
        this.ioUtilsWrapper = ioUtilsWrapper;
        this.msalAuthenticatorFactory = msalAuthenticatorFactory;
    }

    /**
     * {@inheritDoc}
     *
     * @throws NullPointerException if {@code tenantId} is {@code null} or the empty string
     */
    @Override
    public AzureAdRestClient create(final String clientId, final String clientSecret, final String tenantId, final AzureApiUriResolver endpointDataProvider, final long connectionTimeout, final long readTimeout) {
        // Strings.emptyToNull folds "" into null, so both cases fail the same non-null check.
        Preconditions.checkNotNull(Strings.emptyToNull(tenantId), "Tenant ID not specified");
        return new AzureAdRestClient(
                createJerseyClient(clientId, clientSecret, tenantId, endpointDataProvider, connectionTimeout, readTimeout),
                endpointDataProvider,
                ioUtilsWrapper);
    }

    /**
     * Creates a Jersey client and configures it for Azure AD.
     * <p>
     * The client has Jackson JSON support and both timeouts applied. It also carries a token-refresh filter backed
     * by an authenticator built from the given credentials.
     * <p>
     * Jersey 1.x interprets the timeouts as milliseconds. NOTE(review): confirm callers pass milliseconds.
     */
    @VisibleForTesting
    Client createJerseyClient(String clientId, String clientSecret, String tenantId, AzureApiUriResolver azureApiUriResolver, long connectionTimeout, long readTimeout) {
        final Client client = Client.create(jsonClientConfig());
        client.setConnectTimeout(toIntTimeout(connectionTimeout, "connection"));
        client.setReadTimeout(toIntTimeout(readTimeout, "read"));

        final MsGraphApiAuthenticator authenticator =
                msalAuthenticatorFactory.create(clientId, clientSecret, tenantId, azureApiUriResolver);
        client.addFilter(new AzureAdRefreshTokenFilter(newTokenRefresher(authenticator)));
        return client;
    }

    /**
     * Returns a fresh client configuration with Jackson registered as a JSON provider singleton.
     */
    private static DefaultClientConfig jsonClientConfig() {
        final DefaultClientConfig clientConfig = new DefaultClientConfig();
        clientConfig.getSingletons().add(new JacksonJaxbJsonProvider());
        return clientConfig;
    }

    /**
     * Narrows a timeout to {@code int} using a saturated cast.
     * <p>
     * Values outside the int range become {@code Integer.MAX_VALUE} or {@code Integer.MIN_VALUE}. A DEBUG message is
     * logged whenever the value had to be clamped.
     *
     * @param requested   the timeout as supplied by the caller
     * @param timeoutType label used only in the log message ("connection" or "read")
     * @return the timeout as an {@code int}
     */
    private static int toIntTimeout(long requested, String timeoutType) {
        final int clamped = Ints.saturatedCast(requested);
        final boolean overflowed = clamped != requested;
        if (overflowed) {
            log.debug("Specified value {} for {} timeout cannot be represented as an integer, performing saturated cast to {}",
                    requested, timeoutType, clamped);
        }
        return clamped;
    }

    @Override
    public AzureAdPagingWrapper create(AzureAdRestClient restClient) {
        return new AzureAdPagingWrapper(restClient);
    }

    /**
     * Wraps a Guava loading cache in an {@link AzureAdTokenRefresher}.
     * <p>
     * On a miss, the cache asks {@code authenticator} for a new access token. No expiry is configured on the cache
     * here. NOTE(review): invalidation is presumably driven by {@link AzureAdTokenRefresher} — confirm.
     */
    private static AzureAdTokenRefresher newTokenRefresher(final MsGraphApiAuthenticator authenticator) {
        return new AzureAdTokenRefresher(CacheBuilder.newBuilder().build(new AccessTokenLoader(authenticator)));
    }

    /**
     * Cache loader that fetches an access token from the authenticator.
     * <p>
     * It accepts only {@link AzureAdTokenRefresher#AZURE_AD_TOKEN_CACHE_KEY}. Any other key fails
     * {@code Preconditions.checkArgument}.
     */
    private static final class AccessTokenLoader extends CacheLoader<String, String> {

        private final MsGraphApiAuthenticator authenticator;

        AccessTokenLoader(final MsGraphApiAuthenticator authenticator) {
            this.authenticator = authenticator;
        }

        @Override
        public String load(final String key) throws Exception {
            Preconditions.checkArgument(AzureAdTokenRefresher.AZURE_AD_TOKEN_CACHE_KEY.equals(key));
            return authenticator.getApiToken().accessToken();
        }
    }
}
