package tech.relaycorp.relaynet.crypto

import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.cms.ContentInfo
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.jcajce.JcaCertStore
import org.bouncycastle.cert.selector.X509CertificateHolderSelector
import org.bouncycastle.cms.CMSException
import org.bouncycastle.cms.CMSProcessableByteArray
import org.bouncycastle.cms.CMSSignedData
import org.bouncycastle.cms.CMSSignedDataGenerator
import org.bouncycastle.cms.CMSTypedData
import org.bouncycastle.cms.SignerInfoGenerator
import org.bouncycastle.cms.SignerInformation
import org.bouncycastle.cms.jcajce.JcaSignerInfoGeneratorBuilder
import org.bouncycastle.cms.jcajce.JcaSimpleSignerInfoVerifierBuilder
import org.bouncycastle.operator.ContentSigner
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.bouncycastle.operator.jcajce.JcaDigestCalculatorProviderBuilder
import org.bouncycastle.util.CollectionStore
import org.bouncycastle.util.Selector
import tech.relaycorp.relaynet.BC_PROVIDER
import tech.relaycorp.relaynet.HashingAlgorithm
import tech.relaycorp.relaynet.wrappers.x509.Certificate
import java.io.IOException
import java.security.PrivateKey
import java.security.PublicKey

/**
 * Relaynet-specific representation of a CMS SignedData value.
 */
class SignedData(internal val bcSignedData: CMSSignedData) {
    /**
     * The signed plaintext, if it was encapsulated.
     */
    val plaintext: ByteArray? by lazy { bcSignedData.signedContent?.content as ByteArray? }

    /**
     * The signer's certificate, if it was encapsulated.
     *
     * The certificate is looked up among the encapsulated certificates using only the issuer and
     * serial number from the signer id of the sole SignerInfo. It's `null` when no encapsulated
     * certificate matches.
     *
     * @throws SignedDataException if the value doesn't contain exactly one SignerInfo
     */
    val signerCertificate: Certificate? by lazy {
        val signerInfo = getSignerInfo(bcSignedData)

        // X509CertificateHolderSelector isn't statically typed as a Selector of
        // X509CertificateHolder, so the certificate store only accepts it after this unchecked cast
        @Suppress("UNCHECKED_CAST")
        val signerCertSelector = X509CertificateHolderSelector(
            signerInfo.sid.issuer,
            signerInfo.sid.serialNumber
        ) as Selector<X509CertificateHolder>

        bcSignedData.certificates
            .getMatches(signerCertSelector)
            .firstOrNull()
            ?.let { Certificate(it) }
    }

    /**
     * Set of encapsulated certificates.
     */
    val certificates: Set<Certificate> by lazy {
        (bcSignedData.certificates as CollectionStore).map { Certificate(it) }.toSet()
    }

    /**
     * Return the ASN.1 encoding of the underlying [CMSSignedData] value.
     *
     * The output can be parsed back with [deserialize].
     */
    fun serialize(): ByteArray = bcSignedData.encoded

    /**
     * Verify signature.
     *
     * Exactly one source must be available for each of the plaintext and the signer's key. Either
     * the value is encapsulated in the SignedData, or it is passed explicitly. Supplying both, or
     * neither, is an error.
     *
     * @param expectedPlaintext The plaintext to be verified if none is encapsulated
     * @param signerPublicKey The signer's public key if a corresponding certificate isn't
     *     encapsulated
     * @throws SignedDataException if the plaintext or signer key is missing or ambiguous, if the
     *     value doesn't contain exactly one SignerInfo, or if the signature is invalid
     */
    @Throws(SignedDataException::class)
    fun verify(expectedPlaintext: ByteArray? = null, signerPublicKey: PublicKey? = null) {
        if (plaintext != null && expectedPlaintext != null) {
            throw SignedDataException(
                "No specific plaintext should be expected because one is already encapsulated"
            )
        }
        val signedPlaintext = plaintext
            ?: expectedPlaintext
            ?: throw SignedDataException("Plaintext should be encapsulated or explicitly set")

        // Read the lazy property once so that the compiler can smart-cast it below (no `!!`)
        val encapsulatedSignerCertificate = signerCertificate
        if (encapsulatedSignerCertificate != null && signerPublicKey != null) {
            throw SignedDataException(
                "No specific signer certificate should be expected because one is already " +
                    "encapsulated"
            )
        } else if (encapsulatedSignerCertificate == null && signerPublicKey == null) {
            throw SignedDataException(
                "Signer certificate should be encapsulated or explicitly set"
            )
        }

        // Rebuild the SignedData with the plaintext attached so the signature can be checked
        // against it, whether or not the plaintext was encapsulated originally
        val signedData = CMSSignedData(
            CMSProcessableByteArray(signedPlaintext),
            bcSignedData.toASN1Structure()
        )
        val signerInfo = getSignerInfo(signedData)
        val verifierBuilder = JcaSimpleSignerInfoVerifierBuilder().setProvider(BC_PROVIDER)
        val verifier = if (encapsulatedSignerCertificate != null) {
            verifierBuilder.build(encapsulatedSignerCertificate.certificateHolder)
        } else {
            verifierBuilder.build(signerPublicKey)
        }
        val isValid = try {
            signerInfo.verify(verifier)
        } catch (exc: CMSException) {
            throw SignedDataException("Invalid signature", exc)
        }
        if (!isValid) {
            throw SignedDataException("Invalid signature")
        }
    }

    companion object {
        private val signatureAlgorithmMap = mapOf(
            HashingAlgorithm.SHA256 to "SHA256WITHRSAANDMGF1",
            HashingAlgorithm.SHA384 to "SHA384WITHRSAANDMGF1",
            HashingAlgorithm.SHA512 to "SHA512WITHRSAANDMGF1"
        )

        /**
         * Generate SignedData value with a SignerInfo using an IssuerAndSerialNumber id.
         *
         * @param plaintext The content to sign
         * @param signerPrivateKey The key used to produce the RSA-PSS (MGF1) signature
         * @param signerCertificate The certificate whose issuer and serial number identify the signer
         * @param encapsulatedCertificates Certificates to include in the SignedData value. The
         *     signer's certificate is only encapsulated if it's part of this set
         * @param hashingAlgorithm The hashing algorithm; SHA-256 when `null`
         * @param encapsulatePlaintext Whether to include [plaintext] in the SignedData value
         */
        @JvmStatic
        fun sign(
            plaintext: ByteArray,
            signerPrivateKey: PrivateKey,
            signerCertificate: Certificate,
            encapsulatedCertificates: Set<Certificate> = setOf(),
            hashingAlgorithm: HashingAlgorithm? = null,
            encapsulatePlaintext: Boolean = true
        ): SignedData {
            val contentSigner = makeContentSigner(signerPrivateKey, hashingAlgorithm)
            val signerInfoGenerator = makeSignerInfoGeneratorBuilder().build(
                contentSigner,
                signerCertificate.certificateHolder
            )
            return sign(
                plaintext,
                signerInfoGenerator,
                encapsulatedCertificates,
                encapsulatePlaintext
            )
        }

        /**
         * Generate SignedData value with a SignerInfo using a SubjectKeyIdentifier.
         *
         * The SubjectKeyIdentifier is empty and no certificates are encapsulated. The signer's
         * public key must therefore be passed explicitly to [verify].
         *
         * @param plaintext The content to sign
         * @param signerPrivateKey The key used to produce the RSA-PSS (MGF1) signature
         * @param hashingAlgorithm The hashing algorithm; SHA-256 when `null`
         * @param encapsulatePlaintext Whether to include [plaintext] in the SignedData value
         */
        @JvmStatic
        fun sign(
            plaintext: ByteArray,
            signerPrivateKey: PrivateKey,
            hashingAlgorithm: HashingAlgorithm? = null,
            encapsulatePlaintext: Boolean = true
        ): SignedData {
            val contentSigner = makeContentSigner(signerPrivateKey, hashingAlgorithm)
            val signerInfoGenerator = makeSignerInfoGeneratorBuilder().build(
                contentSigner,
                byteArrayOf()
            )
            return sign(
                plaintext,
                signerInfoGenerator,
                emptySet(),
                encapsulatePlaintext
            )
        }

        /**
         * Shared implementation of the public `sign` overloads.
         *
         * It builds the SignedData value from an already configured [signerInfoGenerator].
         */
        private fun sign(
            plaintext: ByteArray,
            signerInfoGenerator: SignerInfoGenerator,
            encapsulatedCertificates: Set<Certificate>,
            encapsulatePlaintext: Boolean
        ): SignedData {
            val signedDataGenerator = CMSSignedDataGenerator()

            signedDataGenerator.addSignerInfoGenerator(signerInfoGenerator)

            val certs = JcaCertStore(encapsulatedCertificates.map { it.certificateHolder })
            signedDataGenerator.addCertificates(certs)

            val plaintextCms: CMSTypedData = CMSProcessableByteArray(plaintext)
            val bcSignedData = signedDataGenerator.generate(plaintextCms, encapsulatePlaintext)
            return SignedData(
                // Work around BC bug that keeps the plaintext encapsulated in the CMSSignedData
                // instance even if it's not encapsulated
                if (encapsulatePlaintext) bcSignedData
                else CMSSignedData(bcSignedData.toASN1Structure())
            )
        }

        /** Create a SignerInfo generator builder backed by a JCA digest calculator provider. */
        private fun makeSignerInfoGeneratorBuilder() = JcaSignerInfoGeneratorBuilder(
            JcaDigestCalculatorProviderBuilder().build()
        )

        /**
         * Create an RSA-PSS (MGF1) content signer for [signerPrivateKey].
         *
         * It uses [hashingAlgorithm], or SHA-256 when that is `null`.
         */
        private fun makeContentSigner(
            signerPrivateKey: PrivateKey,
            hashingAlgorithm: HashingAlgorithm?
        ): ContentSigner {
            val algorithm = hashingAlgorithm ?: HashingAlgorithm.SHA256
            val signerBuilder =
                JcaContentSignerBuilder(signatureAlgorithmMap[algorithm]).setProvider(BC_PROVIDER)
            return signerBuilder.build(signerPrivateKey)
        }

        /**
         * Parse a serialized SignedData value wrapped in a ContentInfo.
         *
         * Any data after the first ASN.1 value is ignored.
         *
         * @throws SignedDataException if [serialization] is empty or can't be decoded, if it isn't
         *     a ContentInfo, or if the ContentInfo doesn't wrap a valid SignedData value
         */
        @JvmStatic
        fun deserialize(serialization: ByteArray): SignedData {
            if (serialization.isEmpty()) {
                throw SignedDataException("Value cannot be empty")
            }
            val asn1Object = try {
                // `use` makes sure the stream is closed even if decoding fails
                ASN1InputStream(serialization).use { it.readObject() }
            } catch (_: IOException) {
                throw SignedDataException("Value is not DER-encoded")
            }
            val contentInfo = try {
                ContentInfo.getInstance(asn1Object)
            } catch (_: IllegalArgumentException) {
                throw SignedDataException("SignedData value is not wrapped in ContentInfo")
            } catch (_: ClassCastException) {
                // NOTE(review): depending on the BouncyCastle version, a SEQUENCE with unexpected
                // item types may be rejected with a ClassCastException instead
                throw SignedDataException("SignedData value is not wrapped in ContentInfo")
            }
            // A ContentInfo without content can't wrap a SignedData value, so don't let
            // CMSSignedData try to process it
            if (contentInfo.content == null) {
                throw SignedDataException("ContentInfo wraps invalid SignedData value")
            }
            val bcSignedData = try {
                CMSSignedData(contentInfo)
            } catch (_: CMSException) {
                throw SignedDataException("ContentInfo wraps invalid SignedData value")
            }
            return SignedData(bcSignedData)
        }

        /**
         * Return the only SignerInfo in [bcSignedData].
         *
         * @throws SignedDataException if there isn't exactly one SignerInfo
         */
        private fun getSignerInfo(bcSignedData: CMSSignedData): SignerInformation {
            val signersCount = bcSignedData.signerInfos.size()
            if (signersCount != 1) {
                throw SignedDataException(
                    "SignedData should contain exactly one SignerInfo (got $signersCount)"
                )
            }
            return bcSignedData.signerInfos.first()
        }
    }
}
