package dev.fitko.fitconnect.core.keys;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.jose.jwk.KeyOperation;
import com.nimbusds.jose.jwk.RSAKey;
import dev.fitko.fitconnect.api.config.ApplicationConfig;
import dev.fitko.fitconnect.api.domain.model.destination.PublicDestination;
import dev.fitko.fitconnect.api.domain.model.jwk.ApiJwk;
import dev.fitko.fitconnect.api.domain.model.jwk.ApiJwkSet;
import dev.fitko.fitconnect.api.domain.validation.ValidationResult;
import dev.fitko.fitconnect.api.exceptions.internal.InvalidKeyException;
import dev.fitko.fitconnect.api.exceptions.internal.RestApiException;
import dev.fitko.fitconnect.api.services.auth.OAuthService;
import dev.fitko.fitconnect.api.services.http.HttpClient;
import dev.fitko.fitconnect.api.services.keys.KeyService;
import dev.fitko.fitconnect.api.services.validation.ValidationService;
import dev.fitko.fitconnect.core.http.HttpHeaders;
import dev.fitko.fitconnect.core.http.MimeTypes;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link KeyService} implementation that downloads public JWKs over HTTP, converts them to {@link RSAKey}s and
 * validates them with the configured {@link ValidationService}.
 *
 * <p>Destination keys are fetched from the submission API ({@value #DESTINATIONS_KEY_PATH}) with a bearer token
 * obtained from the {@link OAuthService}. Signature keys of the submission service (or any other issuer) are looked
 * up by key id in the JWK set published at {@value #WELL_KNOWN_KEYS_PATH} below the respective base URL; these
 * requests are sent without an {@code Authorization} header.
 *
 * <p>If validation fails, an {@link InvalidKeyException} is thrown unless
 * {@link ApplicationConfig#isAllowInsecurePublicKey()} is set, in which case the problem is only logged as a warning.
 */
public class PublicKeyApiService implements KeyService {
    public static final String DESTINATIONS_KEY_PATH = "/v2/destinations/%s/keys/%s";
    public static final String WELL_KNOWN_KEYS_PATH = "/.well-known/jwks.json";

    private static final Logger LOGGER = LoggerFactory.getLogger(PublicKeyApiService.class);
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private final ApplicationConfig config;
    private final ValidationService validationService;
    private final HttpClient httpClient;
    private final OAuthService authService;

    /**
     * Creates a service that can fetch both destination keys (authenticated) and well-known keys.
     *
     * @param config provides the submission base URL and the insecure-public-key policy
     * @param httpClient client used for all key requests
     * @param authService supplies the bearer token for destination key requests
     * @param validationService validates every fetched key before it is returned
     */
    public PublicKeyApiService(
            final ApplicationConfig config,
            final HttpClient httpClient,
            final OAuthService authService,
            final ValidationService validationService) {
        this.config = config;
        this.httpClient = httpClient;
        this.authService = authService;
        this.validationService = validationService;
    }

    /**
     * Creates a service without an {@link OAuthService}. Only the well-known key lookups are usable; the destination
     * key lookups need a bearer token and fail with an {@link IllegalStateException}.
     *
     * @param config provides the submission base URL and the insecure-public-key policy
     * @param httpClient client used for all key requests
     * @param validationService validates every fetched key before it is returned
     */
    public PublicKeyApiService(
            final ApplicationConfig config, final HttpClient httpClient, final ValidationService validationService) {
        this(config, httpClient, null, validationService);
    }

    /**
     * Fetches and validates the public encryption key referenced by the destination.
     *
     * @throws InvalidKeyException if the key cannot be parsed or fails validation (and insecure keys are not allowed)
     * @throws RestApiException if the HTTP request fails
     * @throws IllegalStateException if this instance was created without an {@link OAuthService}
     */
    @Override
    public RSAKey getPublicEncryptionKey(final PublicDestination destination) {
        final String url = destinationKeyUrl(destination.getDestinationId(), destination.getEncryptionKid());
        final ApiJwk publicKey = performRequest(url, ApiJwk.class, buildHeaders(true));
        return getValidatedEncryptionKey(publicKey);
    }

    /**
     * Fetches a destination's public signature key by id and validates it for signature verification at the given
     * date.
     *
     * @throws InvalidKeyException if the key cannot be parsed or fails validation (and insecure keys are not allowed)
     * @throws RestApiException if the HTTP request fails
     * @throws IllegalStateException if this instance was created without an {@link OAuthService}
     */
    @Override
    public RSAKey getDestinationPublicSignatureKey(
            final UUID destinationId, final String keyId, final Date validationDate) {
        final String url = destinationKeyUrl(destinationId, keyId);
        final ApiJwk signatureKey = performRequest(url, ApiJwk.class, buildHeaders(true));
        return getValidatedSignatureKey(toRSAKey(signatureKey), validationDate);
    }

    /**
     * Looks up a signature key of the submission service in its well-known JWK set.
     *
     * @throws InvalidKeyException if no key with this id exists, or it fails parsing/validation
     * @throws RestApiException if the HTTP request fails
     */
    @Override
    public RSAKey getSubmissionServicePublicSignatureKey(final String keyId, final Date validationDate) {
        return fetchWellKnownSignatureKey(keyId, config.getSubmissionBaseUrl(), validationDate);
    }

    /**
     * Looks up a signature key in the well-known JWK set published below {@code baseUrl}.
     *
     * @throws InvalidKeyException if no key with this id exists, or it fails parsing/validation
     * @throws RestApiException if the HTTP request fails
     */
    @Override
    public RSAKey getPublicSignatureWellKnownKey(final String keyId, final String baseUrl, final Date validationDate) {
        return fetchWellKnownSignatureKey(keyId, baseUrl, validationDate);
    }

    /** Downloads the JWK set at {@code baseUrl + WELL_KNOWN_KEYS_PATH}, selects {@code keyId} and validates it. */
    private RSAKey fetchWellKnownSignatureKey(final String keyId, final String baseUrl, final Date validationDate) {
        // Plain concatenation: the base URL must never be treated as a format string, because a '%' in it
        // (e.g. percent-encoding) would be misinterpreted by String.format.
        final String url = baseUrl + WELL_KNOWN_KEYS_PATH;
        final ApiJwkSet wellKnownKeys = performRequest(url, ApiJwkSet.class, buildHeaders(false));
        final RSAKey key = findKeyById(keyId, wellKnownKeys)
                .orElseThrow(() ->
                        new InvalidKeyException("Key with id " + keyId + " could not be found at url " + url));
        return getValidatedSignatureKey(key, validationDate);
    }

    /** Builds the submission-API URL of a single destination key. */
    private String destinationKeyUrl(final UUID destinationId, final String keyId) {
        // Only the path template is formatted; the configured base URL is appended verbatim.
        // NOTE(review): keyId is inserted into the path without percent-encoding; if it can originate from
        // untrusted input (e.g. a JWS header), consider encoding it as a path segment.
        return config.getSubmissionBaseUrl() + String.format(DESTINATIONS_KEY_PATH, destinationId, keyId);
    }

    /** Converts the JWK and validates it for the WRAP_KEY operation. */
    private RSAKey getValidatedEncryptionKey(final ApiJwk publicKey) {
        final RSAKey rsaKey = toRSAKey(publicKey);
        final ValidationResult result = validationService.validatePublicKey(rsaKey, KeyOperation.WRAP_KEY);
        validateResult(result, "Invalid public encryption key");
        return rsaKey;
    }

    /** Validates the key for the VERIFY operation at the given date. */
    private RSAKey getValidatedSignatureKey(final RSAKey rsaKey, final Date validationDate) {
        final ValidationResult result =
                validationService.validatePublicKey(rsaKey, validationDate, KeyOperation.VERIFY);
        validateResult(result, "Public signature key is not valid");
        return rsaKey;
    }

    /**
     * Returns the first key in the set whose {@code kid} equals {@code keyId}, converted to an {@link RSAKey}.
     * A missing set, a missing key list, or keys without a {@code kid} simply produce no match.
     */
    private Optional<RSAKey> findKeyById(final String keyId, final ApiJwkSet keySet) {
        final List<ApiJwk> keys = keySet == null ? null : keySet.getKeys();
        if (keyId == null || keys == null) {
            return Optional.empty();
        }
        return keys.stream()
                // keyId is known to be non-null here, so JWKs without a kid do not cause an NPE
                .filter(key -> key != null && keyId.equals(key.getKid()))
                .map(this::toRSAKey)
                .findFirst();
    }

    /**
     * Throws {@link InvalidKeyException} for a failed validation, or only logs a warning when insecure public keys
     * are explicitly allowed by the configuration.
     */
    private void validateResult(final ValidationResult validationResult, final String message) {
        if (validationResult.hasError()) {
            if (config.isAllowInsecurePublicKey()) {
                LOGGER.warn(message, validationResult.getError());
            } else {
                throw new InvalidKeyException(message, validationResult.getError());
            }
        }
    }

    /** Round-trips the API model through JSON to obtain a Nimbus {@link RSAKey}. */
    private RSAKey toRSAKey(final ApiJwk jwk) {
        try {
            return RSAKey.parse(MAPPER.writeValueAsString(jwk));
        } catch (final JsonProcessingException | ParseException e) {
            throw new InvalidKeyException("Key could not be parsed", e);
        }
    }

    /** Performs a GET on the fully built {@code url} and returns the deserialized body. */
    private <T> T performRequest(final String url, final Class<T> responseType, final Map<String, String> headers) {
        try {
            return httpClient.get(url, headers, responseType).getBody();
        } catch (final RestApiException e) {
            throw new RestApiException("Request failed", e);
        }
    }

    /**
     * Builds JSON/UTF-8 accept headers, optionally adding a bearer token.
     *
     * @throws IllegalStateException if authorization is requested but no {@link OAuthService} was configured
     */
    private Map<String, String> buildHeaders(final boolean withAuthorization) {
        final Map<String, String> headers = new HashMap<>();
        headers.put(HttpHeaders.ACCEPT, MimeTypes.APPLICATION_JSON);
        headers.put(HttpHeaders.ACCEPT_CHARSET, StandardCharsets.UTF_8.toString());
        if (withAuthorization) {
            if (authService == null) {
                throw new IllegalStateException(
                        "An OAuthService is required to request destination keys, but none was configured");
            }
            headers.put(
                    HttpHeaders.AUTHORIZATION,
                    "Bearer " + authService.getCurrentToken().getAccessToken());
        }
        return headers;
    }
}
