/*
 * Decompiled with CFR 0.152.
 */
package se.litsec.swedisheid.opensaml.saml2.signservice;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.proc.JWSVerifierFactory;
import com.nimbusds.jwt.SignedJWT;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.cert.X509Certificate;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import net.shibboleth.utilities.java.support.resolver.ResolverException;
import org.joda.time.DateTime;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.security.credential.UsageType;
import org.opensaml.security.x509.X509Credential;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import se.litsec.opensaml.saml2.attribute.AttributeUtils;
import se.litsec.opensaml.saml2.metadata.MetadataUtils;
import se.litsec.opensaml.saml2.metadata.provider.MetadataProvider;
import se.litsec.opensaml.saml2.metadata.provider.StaticMetadataProvider;
import se.litsec.swedisheid.opensaml.saml2.signservice.SADValidationException;
import se.litsec.swedisheid.opensaml.saml2.signservice.sap.SAD;
import se.litsec.swedisheid.opensaml.saml2.signservice.sap.SADRequest;

public class SADParser {
    private SADParser() {
    }

    public static SAD parse(String sadJwt) throws IOException {
        try {
            SignedJWT signedJwt = SignedJWT.parse((String)sadJwt);
            String payload = signedJwt.getPayload().toBase64URL().toString();
            return SAD.fromJson(new String(Base64.getUrlDecoder().decode(payload), Charset.forName("UTF-8")));
        }
        catch (ParseException e) {
            throw new IOException(e);
        }
    }

    public static SADValidator getValidator(X509Certificate ... validationCertificates) {
        return new SADValidator(validationCertificates);
    }

    public static SADValidator getValidator(MetadataProvider metadataProvider) {
        return new SADValidator(metadataProvider);
    }

    public static SADValidator getValidator(EntityDescriptor idpMetadata) {
        return new SADValidator(idpMetadata);
    }

    public static class SADValidator {
        private Logger logger = LoggerFactory.getLogger(SADValidator.class);
        private List<X509Certificate> validationCertificates;
        private MetadataProvider metadataProvider;
        private static final JWSVerifierFactory verifierFactory = new DefaultJWSVerifierFactory();

        public SADValidator(X509Certificate ... certificates) {
            this.validationCertificates = Arrays.asList(certificates);
        }

        public SADValidator(MetadataProvider metadataProvider) {
            this.metadataProvider = metadataProvider;
        }

        public SADValidator(EntityDescriptor idpMetadata) {
            try {
                this.metadataProvider = new StaticMetadataProvider(idpMetadata);
            }
            catch (MarshallingException e) {
                throw new SecurityException("Invalid IdP metadata", e);
            }
        }

        public SAD validate(AuthnRequest authnRequest, Assertion assertion) throws SADValidationException, IllegalArgumentException {
            String msg;
            SAD sad;
            SignedJWT signedJwt;
            long now = System.currentTimeMillis() / 1000L;
            SADRequest sadRequest = null;
            if (authnRequest.getExtensions() != null) {
                sadRequest = authnRequest.getExtensions().getUnknownXMLObjects().stream().filter(SADRequest.class::isInstance).map(SADRequest.class::cast).findFirst().orElse(null);
            }
            if (sadRequest == null) {
                String msg2 = String.format("AuthnRequest '%s' does not contain a SADRequest", authnRequest.getID());
                this.logger.info(msg2);
                throw new IllegalArgumentException(msg2);
            }
            if (assertion.getAttributeStatements().isEmpty()) {
                String msg3 = String.format("Assertion '%s' does not contain any attributes (and thus no SAD)", assertion.getID());
                this.logger.info(msg3);
                throw new SADValidationException(SADValidationException.ErrorCode.NO_SAD_ATTRIBUTE, msg3);
            }
            List attributes = ((AttributeStatement)assertion.getAttributeStatements().get(0)).getAttributes();
            Attribute sadAttribute = AttributeUtils.getAttribute((String)"urn:oid:1.2.752.201.3.12", (List)attributes).orElse(null);
            if (sadAttribute == null) {
                String msg4 = String.format("Assertion '%s' does not contain a SAD attribute", assertion.getID());
                this.logger.info(msg4);
                throw new SADValidationException(SADValidationException.ErrorCode.NO_SAD_ATTRIBUTE, msg4);
            }
            try {
                signedJwt = SignedJWT.parse((String)AttributeUtils.getAttributeStringValue((Attribute)sadAttribute));
                String payload = signedJwt.getPayload().toBase64URL().toString();
                sad = SAD.fromJson(new String(Base64.getUrlDecoder().decode(payload), Charset.forName("UTF-8")));
            }
            catch (IOException | ParseException e) {
                throw new SADValidationException(SADValidationException.ErrorCode.JWT_PARSE_ERROR, "Failed to parse SAD JWT", e);
            }
            if (sad.getSeElnSadext() == null) {
                msg = "seElnSadext extension claims are missing from SAD";
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, msg);
            }
            if (sad.getSeElnSadext().getAttributeName() == null) {
                msg = "SAD does not contain the attribute name (attr) for the subject";
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, msg);
            }
            Attribute subjectAttribute = AttributeUtils.getAttribute((String)sad.getSeElnSadext().getAttributeName(), (List)attributes).orElse(null);
            if (subjectAttribute == null) {
                String msg5 = String.format("Assertion '%s' does not contain a '%s' attribute - this is listed as the subject attribute in the SAD", assertion.getID(), sad.getSeElnSadext().getAttributeName());
                this.logger.info(msg5);
                throw new SADValidationException(SADValidationException.ErrorCode.MISSING_SUBJECT_ATTRIBUTE, msg5);
            }
            String loa = SADValidator.getLoa(assertion);
            if (loa == null) {
                String msg6 = String.format("Assertion '%s' does not contain a LoA URI", assertion.getID());
                this.logger.error(msg6);
                throw new IllegalArgumentException(msg6);
            }
            if (sadRequest.getDocCount() == null) {
                throw new IllegalArgumentException("Bad SADRequest - missing DocCount");
            }
            return this.validate(signedJwt, sad, now, assertion.getIssuer().getValue(), sadRequest.getRequesterID(), AttributeUtils.getAttributeStringValue((Attribute)subjectAttribute), loa, sadRequest.getID(), sadRequest.getDocCount(), sadRequest.getSignRequestID());
        }

        public SAD validate(String sadJwt, String idpEntityID, String expectedRecipientEntityID, String expectedSubject, String expectedLoa, String sadRequestID, int expectedNoDocs, String signRequestID) throws SADValidationException {
            long now = System.currentTimeMillis() / 1000L;
            try {
                SignedJWT signedJwt = SignedJWT.parse((String)sadJwt);
                String payload = signedJwt.getPayload().toBase64URL().toString();
                SAD sad = SAD.fromJson(new String(Base64.getUrlDecoder().decode(payload), Charset.forName("UTF-8")));
                return this.validate(signedJwt, sad, now, idpEntityID, expectedRecipientEntityID, expectedSubject, expectedLoa, sadRequestID, expectedNoDocs, signRequestID);
            }
            catch (IOException | ParseException e) {
                throw new SADValidationException(SADValidationException.ErrorCode.JWT_PARSE_ERROR, "Failed to parse SAD JWT", e);
            }
        }

        private SAD validate(SignedJWT signedJwt, SAD sad, long now, String idpEntityID, String expectedRecipientEntityID, String expectedSubject, String expectedLoa, String sadRequestID, int expectedNoDocs, String signRequestID) throws SADValidationException {
            this.verifyJwtSignature(signedJwt, idpEntityID);
            if (sad.getJwtId() == null || sad.getJwtId().isEmpty()) {
                String msg = "Invalid SAD JWT - jti is missing";
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, msg);
            }
            if (!Objects.equals(idpEntityID, sad.getIssuer())) {
                String msg = String.format("SAD contains issuer '%s' - expected '%s'", sad.getIssuer(), idpEntityID);
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_ISSUER, msg);
            }
            if (!Objects.equals(expectedRecipientEntityID, sad.getAudience())) {
                String msg = String.format("SAD contains audience '%s' - expected '%s'", sad.getAudience(), expectedRecipientEntityID);
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_AUDIENCE, msg);
            }
            if (sad.getExpiry() == null || sad.getIssuedAt() == null) {
                String msg = "SAD is missing 'exp' and/or 'iat' - Invalid SAD";
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, msg);
            }
            if ((long)sad.getExpiry().intValue() < now) {
                String msg = String.format("SAD has expired - expiration: '%s', current time: '%s'", sad.getExpiryDateTime(), new DateTime(now * 1000L));
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.SAD_EXPIRED, msg);
            }
            if ((long)sad.getIssuedAt().intValue() > now) {
                String msg = String.format("SAD is not yet valid - issue-time: '%s', current time: '%s'", sad.getIssuedAtDateTime(), new DateTime(now * 1000L));
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, msg);
            }
            if (!Objects.equals(expectedSubject, sad.getSubject())) {
                String msg = String.format("SAD contains subject '%s' - expected '%s'", sad.getSubject(), expectedSubject);
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_SUBJECT, msg);
            }
            if (sad.getSeElnSadext() == null) {
                String msg = "seElnSadext extension claims are missing from SAD";
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, msg);
            }
            if (!Objects.equals(sadRequestID, sad.getSeElnSadext().getInResponseTo())) {
                String msg = String.format("SAD contains in-response-to (irt) '%s' - expected SAD to belong to SADRequest with ID '%s'", sad.getSeElnSadext().getInResponseTo(), sadRequestID);
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_IRT, msg);
            }
            if (!Objects.equals(expectedLoa, sad.getSeElnSadext().getLoa())) {
                String msg = String.format("SAD contains LoA '%s' - expected '%s'", sad.getSeElnSadext().getLoa(), expectedLoa);
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_LOA, msg);
            }
            if (!Objects.equals(expectedNoDocs, sad.getSeElnSadext().getNumberOfDocuments())) {
                String msg = String.format("SAD indicated '%s' number of documents - expected '%d'", sad.getSeElnSadext().getNumberOfDocuments(), expectedNoDocs);
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_DOCS, msg);
            }
            if (!Objects.equals(signRequestID, sad.getSeElnSadext().getRequestID())) {
                String msg = String.format("SAD contains SignRequest ID (reqid) '%s' - expected '%s'", sad.getSeElnSadext().getRequestID(), signRequestID);
                this.logger.info(msg);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_SIGNREQUESTID, msg);
            }
            this.logger.debug("SAD with ID '{}' was successfully validated", (Object)sad.getJwtId());
            return sad;
        }

        public void verifyJwtSignature(String sadJwt, String idpEntityID) throws SADValidationException {
            try {
                this.verifyJwtSignature(SignedJWT.parse((String)sadJwt), idpEntityID);
            }
            catch (ParseException e) {
                throw new SADValidationException(SADValidationException.ErrorCode.JWT_PARSE_ERROR, "Failed to parse SAD JWT", e);
            }
        }

        private void verifyJwtSignature(SignedJWT signedJwt, String idpEntityID) throws SADValidationException {
            try {
                List<X509Certificate> idpCerts = this.getValidationCertificates(idpEntityID);
                if (idpCerts.isEmpty()) {
                    throw new SADValidationException(SADValidationException.ErrorCode.SIGNATURE_VALIDATION_ERROR, "No suitable IdP signature certificate was found - can not verify SAD JWT signature");
                }
                this.logger.debug("Verifying SAD JWT signature. Will try {} IdP key(s) ...", (Object)idpCerts.size());
                boolean verificationSuccess = false;
                for (X509Certificate idpCert : idpCerts) {
                    try {
                        JWSVerifier verifier = verifierFactory.createJWSVerifier(signedJwt.getHeader(), (Key)idpCert.getPublicKey());
                        if (!verifier.verify(signedJwt.getHeader(), signedJwt.getSigningInput(), signedJwt.getSignature())) continue;
                        this.logger.debug("SAD JWT signature successfully verified");
                        verificationSuccess = true;
                        break;
                    }
                    catch (JOSEException e) {
                        this.logger.debug("Failed to perform signature validation of SAD JWT - {}", (Object)e.getMessage());
                        this.logger.trace("", (Throwable)e);
                    }
                }
                if (!verificationSuccess) {
                    throw new SADValidationException(SADValidationException.ErrorCode.SIGNATURE_VALIDATION_ERROR, "Signature on SAD JWT could not be validated using any of the IdP certificates found");
                }
            }
            catch (ResolverException e) {
                throw new SADValidationException(SADValidationException.ErrorCode.SIGNATURE_VALIDATION_ERROR, "Failed to find validation certificate", e);
            }
        }

        private List<X509Certificate> getValidationCertificates(String idpEntityID) throws ResolverException {
            if (this.validationCertificates != null && !this.validationCertificates.isEmpty()) {
                return this.validationCertificates;
            }
            if (this.metadataProvider != null) {
                Optional metadata = this.metadataProvider.getEntityDescriptor(idpEntityID);
                if (!metadata.isPresent()) {
                    this.logger.warn("No metadata found for IdP '{}' - cannot find key to use when verifying SAD JWT signature", (Object)idpEntityID);
                    return Collections.emptyList();
                }
                List creds = MetadataUtils.getMetadataCertificates((EntityDescriptor)((EntityDescriptor)metadata.get()), (UsageType)UsageType.SIGNING);
                return creds.stream().map(X509Credential::getEntityCertificate).collect(Collectors.toList());
            }
            return Collections.emptyList();
        }

        private static String getLoa(Assertion assertion) {
            try {
                return ((AuthnStatement)assertion.getAuthnStatements().get(0)).getAuthnContext().getAuthnContextClassRef().getAuthnContextClassRef();
            }
            catch (Exception e) {
                return null;
            }
        }
    }
}

