package com.atlassian.asap.core.keys.publickey;

import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.PublicKey;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import com.atlassian.asap.api.exception.CannotRetrieveKeyException;
import com.atlassian.asap.core.exception.PublicKeyNotFoundException;
import com.atlassian.asap.core.exception.PublicKeyRetrievalException;
import com.atlassian.asap.core.keys.PemReader;
import com.atlassian.asap.core.validator.ValidatedKeyId;

import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.utils.HttpClientUtils;
import org.apache.http.entity.ContentType;
import org.apache.http.impl.client.DefaultRedirectStrategy;
import org.apache.http.impl.client.cache.CacheConfig;
import org.apache.http.impl.client.cache.CachingHttpClients;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Reads public keys from web servers using the HTTPS protocol.
 *
 * <p>A key with id {@code k} is fetched by issuing a {@code GET} for {@code baseUrl.resolve(k)} with an
 * {@code Accept: application/x-pem-file} header. The response must be {@code 200 OK} with a
 * {@code application/x-pem-file} content type; its body is then parsed with the configured {@link PemReader}.
 */
public class HttpPublicKeyProvider implements com.atlassian.asap.core.keys.KeyProvider<PublicKey>
{
    /**
     * The default maximum number of connections, applied both per route and in total by {@link #defaultHttpClient()}.
     * NOTE(review): because {@link #defaultHttpClient()} installs its own connection manager, the HttpClient
     * {@code http.maxConnections} system property does not appear to override this limit — confirm before relying on it.
     */
    static final int DEFAULT_MAX_CONNECTIONS = 20;

    /** The only response MIME type accepted for public keys (compared case-insensitively). */
    static final String PEM_MIME_TYPE = "application/x-pem-file";

    /** Value sent in the {@code Accept} request header. */
    static final String ACCEPT_HEADER_VALUE = PEM_MIME_TYPE;

    private static final Logger logger = LoggerFactory.getLogger(HttpPublicKeyProvider.class);

    private final HttpClient httpClient;
    private final PemReader pemReader;
    private final URI baseUrl;

    /**
     * Create a new {@link HttpPublicKeyProvider} instance.
     *
     * @param baseUrl the base url of the public key server; must be absolute, use the {@code https} scheme and end
     *                with a trailing slash so that key ids resolve beneath it
     * @param httpClient the http client to use for communicating with the public key server
     * @param pemReader the pem key reader to use for reading public keys in pem format
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if {@code baseUrl} violates any of the constraints above
     */
    public HttpPublicKeyProvider(URI baseUrl, HttpClient httpClient, PemReader pemReader)
    {
        Objects.requireNonNull(baseUrl, "Base URL cannot be null");
        checkArgument(baseUrl.isAbsolute(), "Base URL must be absolute"); // implies that scheme != null
        checkArgument("https".equals(baseUrl.getScheme()), "Base URL must have https scheme");
        // without a trailing slash, URI.resolve would replace the last path segment instead of appending to it
        checkArgument(StringUtils.endsWith(baseUrl.toString(), "/"), "Base URL does not end with trailing slash: " + baseUrl);

        this.baseUrl = baseUrl;
        this.httpClient = Objects.requireNonNull(httpClient);
        this.pemReader = Objects.requireNonNull(pemReader);
    }

    /**
     * Fetches and parses the public key identified by the given key id.
     *
     * @param validatedKeyId the key id, resolved against the base URL to form the key URL
     * @return the public key served at the key URL
     * @throws PublicKeyNotFoundException if the server answers {@code 404 NOT FOUND}
     * @throws CannotRetrieveKeyException if the request fails, the server answers with any other non-200 status,
     *                                    the body is missing or has a missing/unexpected content type, or the
     *                                    {@link PemReader} rejects the body
     */
    @Override
    public PublicKey getKey(ValidatedKeyId validatedKeyId) throws CannotRetrieveKeyException
    {
        URI keyUrl = baseUrl.resolve(validatedKeyId.getKeyId());
        HttpGet httpGet = new HttpGet(keyUrl);
        httpGet.setHeader(HttpHeaders.ACCEPT, ACCEPT_HEADER_VALUE);
        logger.debug("Fetching public key {}", keyUrl);

        HttpResponse response = null;
        try
        {
            response = httpClient.execute(httpGet);
            int statusCode = response.getStatusLine().getStatusCode();
            switch (statusCode)
            {
                case HttpStatus.SC_OK:
                    return readPublicKeyEntity(response.getEntity(), keyUrl);
                case HttpStatus.SC_NOT_FOUND:
                    // log at debug level because this can be caused by invalid input
                    logger.debug("Public key URL {} returned 404 NOT FOUND", keyUrl);
                    throw new PublicKeyNotFoundException("Encountered 404 NOT FOUND for public key: " + keyUrl);
                default:
                    logger.error("Unexpected HTTP status code {} when trying to retrieve public key URL {}",
                            statusCode, keyUrl);
                    throw new PublicKeyRetrievalException("Unexpected status code");
            }
        }
        catch (IOException e)
        {
            logger.error("A problem occurred when trying to retrieve public key from URL {}", keyUrl, e);
            throw new PublicKeyRetrievalException("Error reading public key from HTTPS key repository");
        }
        finally
        {
            // null-safe; consumes any remaining entity content so the pooled connection can be reused
            HttpClientUtils.closeQuietly(response);
        }
    }

    /**
     * Validates the entity of a {@code 200 OK} response and parses it as a PEM public key.
     *
     * <p>The content type is checked before the body stream is opened. The body is decoded with the charset
     * declared in the content type, falling back to US-ASCII when none is declared.
     *
     * @param entity the response entity, possibly {@code null}
     * @param keyUrl the URL the key was fetched from; used only for logging
     * @return the key produced by the {@link PemReader}
     * @throws IOException if the body cannot be read
     * @throws CannotRetrieveKeyException if the entity is missing or its content type is missing, unparseable or not
     *                                    {@value #PEM_MIME_TYPE}; PemReader failures propagate unchanged
     */
    private PublicKey readPublicKeyEntity(HttpEntity entity, URI keyUrl) throws IOException, CannotRetrieveKeyException
    {
        if (entity == null)
        {
            logger.error("Unexpected empty HTTP response when trying to retrieve public key URL {}", keyUrl);
            throw new PublicKeyRetrievalException("Unexpected empty response");
        }

        ContentType contentType = parseContentType(entity, keyUrl);
        // ContentType.get returns null when the response carries no Content-Type header
        String mimeType = contentType == null ? null : contentType.getMimeType();
        // media types are case-insensitive (RFC 7231 §3.1.1.1); calling on the constant also tolerates null
        if (!PEM_MIME_TYPE.equalsIgnoreCase(mimeType))
        {
            logger.error("Rejecting public key due to Content type {} when retrieving {}. Public Keys must have Content type of {}.",
                    mimeType, keyUrl, PEM_MIME_TYPE);
            throw new PublicKeyRetrievalException("Unexpected public key MIME type");
        }

        Charset charset = Optional.ofNullable(contentType.getCharset()).orElse(StandardCharsets.US_ASCII);
        try (InputStreamReader reader = new InputStreamReader(entity.getContent(), charset))
        {
            return pemReader.readPublicKey(reader);
        }
    }

    /**
     * Parses the {@code Content-Type} of the given entity.
     *
     * @param entity the response entity
     * @param keyUrl the URL the key was fetched from; used only for logging
     * @return the parsed content type, or {@code null} if the entity declares none
     * @throws CannotRetrieveKeyException if the declared charset is unknown or illegal
     */
    private static ContentType parseContentType(HttpEntity entity, URI keyUrl) throws CannotRetrieveKeyException
    {
        try
        {
            return ContentType.get(entity);
        }
        catch (IllegalArgumentException e)
        {
            // UnsupportedCharsetException and IllegalCharsetNameException both extend IllegalArgumentException;
            // report them as a retrieval failure instead of letting an unchecked exception escape getKey
            logger.error("Rejecting public key due to unparseable Content type when retrieving {}", keyUrl, e);
            throw new PublicKeyRetrievalException("Unexpected public key MIME type");
        }
    }

    /**
     * Builds the default client used to talk to public key servers.
     *
     * <p>The client pools up to {@link #DEFAULT_MAX_CONNECTIONS} connections (per route and in total), uses a 5 second
     * connect timeout and a 10 second socket timeout, and keeps a private (non-shared) HTTP cache of up to 128 entries
     * of at most 2048 bytes each, with heuristic caching disabled.
     * NOTE(review): {@link DefaultRedirectStrategy} is used as-is; confirm whether redirects away from {@code https}
     * should be rejected for key retrieval.
     *
     * @return a new caching HTTP client
     */
    static HttpClient defaultHttpClient()
    {
        PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager();
        connectionManager.setDefaultMaxPerRoute(DEFAULT_MAX_CONNECTIONS);
        connectionManager.setMaxTotal(DEFAULT_MAX_CONNECTIONS);

        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectTimeout((int) TimeUnit.SECONDS.toMillis(5))
                .setSocketTimeout((int) TimeUnit.SECONDS.toMillis(10))
                .build();

        CacheConfig cacheConfig = CacheConfig.custom()
                .setMaxCacheEntries(128)
                .setMaxObjectSize(2048) // keys (.pem) are small
                .setHeuristicCachingEnabled(false)
                .setSharedCache(false)
                .setAsynchronousWorkersMax(2)
                .build();

        return CachingHttpClients.custom()
                .setCacheConfig(cacheConfig)
                .setDefaultRequestConfig(requestConfig)
                .setConnectionManager(connectionManager)
                .useSystemProperties()
                .setRedirectStrategy(DefaultRedirectStrategy.INSTANCE)
                .build();
    }
}
