package dev.fitko.fitconnect.core.routing;

import static java.util.Collections.emptyList;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import dev.fitko.fitconnect.api.domain.model.route.Route;
import dev.fitko.fitconnect.api.domain.validation.ValidationResult;
import dev.fitko.fitconnect.api.exceptions.internal.InvalidKeyException;
import dev.fitko.fitconnect.api.exceptions.internal.RestApiException;
import dev.fitko.fitconnect.api.exceptions.internal.ValidationException;
import dev.fitko.fitconnect.api.services.keys.KeyService;
import dev.fitko.fitconnect.api.services.routing.RoutingVerificationService;
import dev.fitko.fitconnect.api.services.validation.ValidationService;
import java.net.URI;
import java.net.URISyntaxException;
import java.text.ParseException;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Verifies routes against their signed destination information.
 *
 * <p>Each route carries a {@code destinationSignature} JWT. A route is kept only if that JWT
 *
 * <ul>
 *   <li>is signed with {@code PS512},
 *   <li>has claims that pass {@code ValidationService#validateDestinationSchema},
 *   <li>lists the requested service (and, if given, the requested region) in its {@code services} claim, and
 *   <li>verifies against the public key obtained from {@link KeyService} for its issuer, where the issuer's
 *       host must end with {@code .fit-connect.fitko.net} or {@code .fit-connect.fitko.dev}.
 * </ul>
 *
 * Routes failing any of these checks are logged at WARN level and dropped from the result.
 */
public class RouteVerifier implements RoutingVerificationService {

    private static final Logger LOGGER = LoggerFactory.getLogger(RouteVerifier.class);

    /** Host suffixes the signature issuer must have for its public key to be trusted. */
    private static final List<String> VALID_KEY_HOSTS = List.of(".fit-connect.fitko.net", ".fit-connect.fitko.dev");

    /** Matches identifiers consisting only of digits; compiled once because it is used per identifier. */
    private static final Pattern NUMERIC_ID = Pattern.compile("\\d+");

    private final KeyService keyService;
    private final ValidationService validationService;

    public RouteVerifier(final KeyService keyService, final ValidationService validationService) {
        this.keyService = keyService;
        this.validationService = validationService;
    }

    /**
     * Returns the routes whose destination signature is valid for the requested service and region.
     *
     * @param routes routes to check
     * @param requestedServiceIdentifier service identifier the routes were requested for
     * @param requestedRegion region identifier the routes were requested for, or {@code null} to skip the
     *     region check
     * @return a new list containing only the routes that passed all checks, in their original order
     */
    @Override
    public List<Route> validateRouteDestinations(
            final List<Route> routes, final String requestedServiceIdentifier, final String requestedRegion) {
        return routes.stream()
                .filter(route -> isValidRoute(route, requestedServiceIdentifier, requestedRegion))
                .collect(Collectors.toList());
    }

    /**
     * Runs all signature checks for one route.
     *
     * <p>Any failure is logged and reported as {@code false}. A single broken route therefore does not
     * abort validation of the remaining routes.
     */
    private boolean isValidRoute(
            final Route route, final String requestedServiceIdentifier, final String requestedRegion) {
        try {
            final String destinationSignature = route.getDestinationSignature();
            if (destinationSignature == null) {
                throw new ValidationException("Route does not contain a destination signature");
            }
            final SignedJWT signature = SignedJWT.parse(destinationSignature);
            final JWSHeader header = signature.getHeader();
            final JWTClaimsSet claims = signature.getJWTClaimsSet();

            checkHeaderAlgorithm(header);
            validatePayloadSchema(claims);
            checkExpectedServices(claims, requestedServiceIdentifier, requestedRegion);
            validateAgainstPublicKey(signature, claims, header.getKeyID());

            return true;
        } catch (final ValidationException e) {
            LOGGER.warn("Route validation failed for destination {}: {}", route.getDestinationId(), e.getMessage());
            return false;
        } catch (final InvalidKeyException e) {
            LOGGER.warn(
                    "Route validation failed for destination {}: Public signature key is invalid: {}",
                    route.getDestinationId(),
                    e.getMessage());
            return false;
        } catch (final RestApiException e) {
            LOGGER.warn(
                    "Route validation failed for destination {}: Could not retrieve public signature key: {}",
                    route.getDestinationId(),
                    e.getMessage());
            return false;
        } catch (final ParseException | JOSEException e) {
            LOGGER.warn(
                    "Route validation failed for destination {}: Signature processing failed: {}",
                    route.getDestinationId(),
                    e.getMessage());
            return false;
        }
    }

    /**
     * Validates the signature claims against the destination schema.
     *
     * @throws ValidationException carrying the validation error if the schema check fails
     */
    private void validatePayloadSchema(final JWTClaimsSet claims) {
        final ValidationResult validationResult = validationService.validateDestinationSchema(claims.toJSONObject());
        if (validationResult.hasError()) {
            throw new ValidationException(validationResult.getError().getMessage(), validationResult.getError());
        }
    }

    /**
     * Verifies the JWT signature with the issuer's public key identified by {@code keyId}.
     *
     * @throws ValidationException if the issuer is not trusted or the signature does not verify
     */
    private void validateAgainstPublicKey(final SignedJWT signature, final JWTClaimsSet claims, final String keyId)
            throws JOSEException, ParseException {
        final RSAKey portalPublicKey = getSignatureValidationKey(claims, keyId);
        if (!signature.verify(new RSASSAVerifier(portalPublicKey))) {
            throw new ValidationException("Invalid destination signature for public key id " + keyId);
        }
    }

    /**
     * Fetches the public signature key for {@code keyId} from the issuer, but only if the issuer is a
     * FIT-Connect host.
     *
     * <p>The issuer comes from the (not yet verified) JWT itself. The host check is therefore the trust
     * anchor: without it, a forged token could name an attacker-controlled key source.
     *
     * @throws ValidationException if the issuer is missing or not a FIT-Connect host
     */
    private RSAKey getSignatureValidationKey(final JWTClaimsSet claims, final String keyId) throws ParseException {
        final String issuer = claims.getIssuer();
        if (!isFitConnectHost(issuer)) {
            throw new ValidationException("Requested signature validation key url '"
                    + issuer
                    + "' is no FIT-Connect host (.fit-connect.fitko.net or .fit-connect.fitko.dev)");
        }
        return keyService.getPublicSignatureWellKnownKey(keyId, issuer, claims.getIssueTime());
    }

    /**
     * Returns whether {@code issuer} is a URI whose host ends with one of {@link #VALID_KEY_HOSTS}.
     *
     * <p>The check runs on the parsed host, not on the raw string. A plain suffix check on the string would
     * also accept {@code https://attacker.example/.fit-connect.fitko.net}, whose host is
     * {@code attacker.example}.
     *
     * @param issuer the {@code iss} claim, may be {@code null}
     * @return {@code true} only for a parseable URI with a FIT-Connect host
     */
    private static boolean isFitConnectHost(final String issuer) {
        if (issuer == null) {
            return false;
        }
        final String host;
        try {
            host = new URI(issuer).getHost();
        } catch (final URISyntaxException ignored) {
            // An unparseable issuer cannot be a trusted host; callers turn this into a ValidationException.
            return false;
        }
        if (host == null) {
            return false;
        }
        // Host names are case-insensitive.
        final String normalizedHost = host.toLowerCase(Locale.ROOT);
        return VALID_KEY_HOSTS.stream().anyMatch(normalizedHost::endsWith);
    }

    /**
     * Checks that the signed {@code services} claim covers the requested service and, if given, the
     * requested region.
     *
     * @throws ValidationException if the claim is missing or malformed, or nothing matches
     */
    private void checkExpectedServices(
            final JWTClaimsSet claims, final String requestedServiceIdentifier, final String requestedRegion)
            throws ParseException {

        final List<?> serviceClaims = claims.getListClaim("services");
        if (serviceClaims == null) {
            throw new ValidationException("Destination signature does not contain a services claim");
        }
        final List<RouteService> services = serviceClaims.stream()
                .map(RouteVerifier::toRouteService)
                .collect(Collectors.toList());

        final var serviceId = getIdFromIdentifier(requestedServiceIdentifier);
        if (services.stream().noneMatch(service -> service.hasMatchingService(serviceId))) {
            throw new ValidationException("Requested service identifier '"
                    + requestedServiceIdentifier
                    + "' is not supported by any of the destinations services");
        }

        // check combination of service and region - ars can be null if the requested region is an
        // areaId or ags
        if (requestedRegion != null) {
            final var regionId = getIdFromIdentifier(requestedRegion);
            if (services.stream().noneMatch(service -> service.hasMatchingRegionAndService(regionId, serviceId))) {
                throw new ValidationException("Requested region '"
                        + requestedRegion
                        + "' does not match any service provided by the destination");
            }
        }
    }

    /**
     * Converts one decoded entry of the {@code services} claim into a {@link RouteService}.
     *
     * @throws ValidationException if the entry is not a JSON object (decoded as a {@link Map})
     */
    @SuppressWarnings("unchecked")
    private static RouteService toRouteService(final Object service) {
        if (!(service instanceof Map)) {
            throw new ValidationException("Service entry in destination signature is not a JSON object");
        }
        // NOTE(review): the value types (lists of strings) are presumably guaranteed by the schema
        // validation in validatePayloadSchema - confirm.
        return new RouteService((Map<String, List<String>>) service);
    }

    /**
     * Returns the last colon-separated segment of an identifier, e.g. {@code "123"} for
     * {@code "urn:a:b:123"}. Purely numeric identifiers are returned unchanged.
     *
     * @return the last segment, or {@code null} if splitting yields no segments (e.g. {@code ":"})
     */
    private static String getIdFromIdentifier(final String identifier) {
        if (isNumericId(identifier)) {
            return identifier;
        }
        return Arrays.stream(identifier.split(":"))
                .reduce((first, second) -> second)
                .orElse(null);
    }

    private static boolean isNumericId(final String identifier) {
        return NUMERIC_ID.matcher(identifier).matches();
    }

    /**
     * Requires the signature to use {@code PS512}.
     *
     * @throws ValidationException for any other algorithm
     */
    private void checkHeaderAlgorithm(final JWSHeader header) {
        if (!JWSAlgorithm.PS512.equals(header.getAlgorithm())) {
            throw new ValidationException("Algorithm in signature header is not " + JWSAlgorithm.PS512);
        }
    }

    /**
     * One entry of the signed {@code services} claim.
     *
     * <p>Holds the region ids ({@code gebietIDs}) and service ids ({@code leistungIDs}), each reduced to
     * its last colon-separated segment.
     */
    static class RouteService {
        private final List<String> regionIds;
        private final List<String> serviceIds;

        protected RouteService(final Map<String, List<String>> service) {
            regionIds = service.getOrDefault("gebietIDs", emptyList()).stream()
                    .map(RouteVerifier::getIdFromIdentifier)
                    .collect(Collectors.toList());
            serviceIds = service.getOrDefault("leistungIDs", emptyList()).stream()
                    .map(RouteVerifier::getIdFromIdentifier)
                    .collect(Collectors.toList());
        }

        /** Returns whether this entry matches both the given region id and the given service id. */
        public boolean hasMatchingRegionAndService(final String regionId, final String serviceId) {
            return hasMatchingRegion(regionId) && hasMatchingService(serviceId);
        }

        /** Returns whether the requested {@code regionId} contains one of this entry's region ids. */
        public boolean hasMatchingRegion(final String regionId) {
            // NOTE(review): this is a substring match, so region id "06" also matches "120600000000".
            // If region ids are hierarchical prefixes (ARS), startsWith may be the intended check - confirm.
            return regionIds.stream().anyMatch(regionId::contains);
        }

        /** Returns whether this entry lists exactly the given service id. */
        public boolean hasMatchingService(final String serviceId) {
            return serviceIds.contains(serviceId);
        }
    }
}
