package org.jfrog.security.ssl;

import org.bouncycastle.asn1.*;
import org.bouncycastle.asn1.x500.RDN;
import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.asn1.x500.style.BCStyle;
import org.bouncycastle.asn1.x500.style.IETFUtils;
import org.bouncycastle.asn1.x509.*;
import org.bouncycastle.cert.CertIOException;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.X509v3CertificateBuilder;
import org.bouncycastle.jce.X509KeyUsage;
import org.bouncycastle.operator.ContentSigner;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
import org.bouncycastle.util.IPAddress;
import org.jfrog.security.util.BCProviderFactory;

import javax.security.auth.x500.X500Principal;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
import java.util.List;

/**
 * Fluent builder for an X.509 v3 certificate that binds a subject public key to a subject name and is signed
 * with the issuer's private key using {@code SHA256WithRSA} through the {@code "BC"} JCA provider.
 *
 * <p>Required before building: issuer ({@link #iss}), issuer private key ({@link #issPrivateKey}),
 * subject ({@link #sub}) and subject public key ({@link #subPublicKey}). Everything else is optional:
 * the serial number defaults to {@code 1}, the validity defaults to the maximum supported expiry, and the
 * custom version extension ({@link #CERT_VERSION_OID}) is always written (value {@code 0} unless set).
 *
 * <p>Instances are mutable and not synchronized; configure and build from a single thread.
 */
public class SignedCertificateBuilder {
    /** Issuer distinguished name. */
    private X500Principal iss;
    /** Key used to sign the certificate. */
    private PrivateKey issPrivateKey;
    /** Subject distinguished name. */
    private X500Principal sub;
    /** Certificate serial number; {@code null} means {@code 1}. */
    private BigInteger serialNumber;
    /** Public key certified by this certificate. */
    private PublicKey subPublicKey;
    /** Validity duration in milliseconds from now; {@code null} means "as long as possible". */
    private Long expireIn;
    /** Value written into the {@link #CERT_VERSION_OID} extension. */
    private int certVersion;
    /** When true, the subject's CN is added to the SAN extension as a DNS name. */
    private boolean useSubForSAN = false;
    /** Explicitly configured subject alternative names (never mutated by build). */
    private List<GeneralName> sanValues = new ArrayList<>();
    private boolean isCA = false;
    private boolean isTLS = false;

    private static final String SIG_ALG = "SHA256WithRSA";
    private static final String BC_PROVIDER = "BC";
    public static final ASN1ObjectIdentifier CERT_VERSION_OID = (new ASN1ObjectIdentifier("2.5.29.17.1")).intern();
    private static final long ONE_DAY_IN_MS = 24 * 60 * 60 * 1000L;
    /** Upper bound for the notAfter date: 1 January 7000, 00:00:00.000, default time zone. */
    private static final long MAX_EXPIRY;

    static {
        Calendar calendar = Calendar.getInstance();
        // Clear first so the current second/millisecond does not leak into the constant.
        calendar.clear();
        calendar.set(7000, Calendar.JANUARY, 1, 0, 0); // We'll need to fix it once the year 7000 will be around the corner...
        MAX_EXPIRY = calendar.getTimeInMillis();
        // Make sure that the factory is initialised; presumably this registers the "BC" provider used for
        // signing below — confirm in BCProviderFactory.
        BCProviderFactory.getProvider();
    }


    /** @return a new, empty builder */
    public static SignedCertificateBuilder builder() {
        return new SignedCertificateBuilder();
    }

    /**
     * Sets the issuer to {@code CN=<issuerCN>}. The concatenated string is parsed as a distinguished name,
     * so DN syntax characters in {@code issuerCN} (e.g. ',' or '+') are interpreted, not escaped.
     */
    public SignedCertificateBuilder iss(String issuerCN) {
        iss = new X500Principal("CN=" + issuerCN);
        return this;
    }

    /** Sets the issuer distinguished name. */
    public SignedCertificateBuilder iss(X500Principal issuer) {
        iss = issuer;
        return this;
    }

    /** Sets the private key that signs the certificate. */
    public SignedCertificateBuilder issPrivateKey(PrivateKey issPrivateKey) {
        this.issPrivateKey = issPrivateKey;
        return this;
    }

    /**
     * Sets the subject to {@code CN=<subject>}. The concatenated string is parsed as a distinguished name,
     * so DN syntax characters in {@code subject} are interpreted, not escaped.
     */
    public SignedCertificateBuilder sub(String subject) {
        sub = new X500Principal("CN=" + subject);
        return this;
    }

    /** Sets the subject distinguished name. */
    public SignedCertificateBuilder sub(X500Principal subject) {
        sub = subject;
        return this;
    }

    /**
     * Sets the serial number. If never set, {@code 1} is used.
     * NOTE(review): a fixed default means every certificate from the same issuer shares serial 1 unless callers
     * set one explicitly.
     */
    public SignedCertificateBuilder serialNumber(BigInteger serialNumber) {
        this.serialNumber = serialNumber;
        return this;
    }

    /** Sets the public key to certify. */
    public SignedCertificateBuilder subPublicKey(PublicKey subPublicKey) {
        this.subPublicKey = subPublicKey;
        return this;
    }

    /**
     * Sets the validity duration in milliseconds, counted from build time. It must not be negative.
     * {@code null} (the default) means the maximum supported expiry.
     */
    public SignedCertificateBuilder expireIn(Long expireIn) {
        this.expireIn = expireIn;
        return this;
    }

    /** Sets the integer stored in the {@link #CERT_VERSION_OID} extension. */
    public SignedCertificateBuilder certVersion(int certVersion) {
        this.certVersion = certVersion;
        return this;
    }

    /**
     * When true, adds critical keyUsage (digitalSignature, keyCertSign) and critical basicConstraints CA=true.
     * This takes precedence over {@link #isTLS(boolean)}.
     */
    public SignedCertificateBuilder isCA(boolean isCA) {
        this.isCA = isCA;
        return this;
    }

    /**
     * When true (and not a CA), adds critical extendedKeyUsage (serverAuth, clientAuth), critical keyUsage
     * (digitalSignature, keyEncipherment) and critical basicConstraints CA=false.
     */
    public SignedCertificateBuilder isTLS(boolean isTLS) {
        this.isTLS = isTLS;
        return this;
    }

    /** Adds a subject alternative name. */
    public SignedCertificateBuilder sanValue(GeneralName generalName) {
        sanValues.add(generalName);
        return this;
    }

    /**
     * Adds a subject alternative name. It is an iPAddress entry if {@code ipOrDns} is a valid IPv4/IPv6
     * address, otherwise a dNSName entry.
     */
    public SignedCertificateBuilder sanIpOrDnsValue(String ipOrDns) {
        int tag = IPAddress.isValid(ipOrDns) ? GeneralName.iPAddress : GeneralName.dNSName;
        sanValues.add(new GeneralName(tag, ipOrDns));
        return this;
    }

    /** Also adds the subject's CN (or, if it has none, its first RDN value) as a DNS subject alternative name. */
    public SignedCertificateBuilder useSubForSAN() {
        useSubForSAN = true;
        return this;
    }

    /**
     * Builds and signs the certificate.
     *
     * @return the signed certificate; it is always an {@link X509Certificate}
     * @throws CertificateGenerationException if a required value is missing or generation/signing fails
     */
    public Certificate build() throws CertificateGenerationException {
        // buildX509Certificate() already parses the signed bytes through the standard "X.509"
        // CertificateFactory, so re-encoding and re-parsing the result would be redundant.
        return buildX509Certificate();
    }

    /**
     * Builds and signs the certificate. The builder is not modified, so it may be called repeatedly.
     *
     * @return the signed X.509 certificate
     * @throws CertificateGenerationException if a required value is missing, {@code expireIn} is negative, or
     *                                        encoding/signing fails (the original exception is the cause)
     */
    public X509Certificate buildX509Certificate() throws CertificateGenerationException {
        try {
            requireSet(iss, "issuer");
            requireSet(issPrivateKey, "issuer private key");
            requireSet(sub, "subject");
            requireSet(subPublicKey, "subject public key");

            // notBefore is back-dated by one day (presumably to tolerate clock skew between peers).
            Date startDate = new Date(System.currentTimeMillis() - ONE_DAY_IN_MS);
            Date endDate = getEndDate(expireIn == null ? Long.MAX_VALUE : expireIn);

            X509v3CertificateBuilder builder = getX509v3CertificateBuilder(startDate, endDate);

            GeneralName[] sanContent = getSANContent();
            if (sanContent.length > 0) {
                builder.addExtension(Extension.subjectAlternativeName, false, new GeneralNames(sanContent));
            }

            addCertificateVersion(builder, certVersion);

            return getSignedCertificate(issPrivateKey, builder);
        } catch (Exception e) {
            throw new CertificateGenerationException("Failed to generate signed certificate: " + e.getMessage(), e);
        }
    }

    /** Fails with a descriptive message when a mandatory builder value was never provided. */
    private static void requireSet(Object value, String description) {
        if (value == null) {
            throw new IllegalStateException(description + " must be set before building the certificate");
        }
    }

    /**
     * Computes notAfter as now + {@code expireIn}. The result is clamped to {@link #MAX_EXPIRY}; this also
     * covers long overflow, which shows up as a negative sum.
     */
    private Date getEndDate(long expireIn) {
        if (expireIn < 0) {
            throw new IllegalArgumentException("'expire in' must be a positive number");
        }
        long endMillis = System.currentTimeMillis() + expireIn;
        endMillis = (endMillis < 0 || endMillis > MAX_EXPIRY) ? MAX_EXPIRY : endMillis;
        return new Date(endMillis);
    }

    /** Creates the unsigned certificate builder with names, serial, validity, key and CA/TLS extensions. */
    private X509v3CertificateBuilder getX509v3CertificateBuilder(Date startDate, Date endDate) throws IOException {
        X509v3CertificateBuilder x509v3CertificateBuilder = new X509v3CertificateBuilder(
                X500Name.getInstance(iss.getEncoded()),
                serialNumber == null ? BigInteger.ONE : serialNumber,
                startDate,
                endDate,
                X500Name.getInstance(sub.getEncoded()),
                // Expects subPublicKey.getEncoded() to be a DER SubjectPublicKeyInfo (JCA format "X.509").
                SubjectPublicKeyInfo.getInstance(subPublicKey.getEncoded()));
        if (isCA) {
            x509v3CertificateBuilder
                    .addExtension(Extension.keyUsage, true,
                            new X509KeyUsage(X509KeyUsage.digitalSignature | X509KeyUsage.keyCertSign))
                    .addExtension(Extension.basicConstraints, true, new BasicConstraints(true));
        } else if (isTLS) {
            KeyPurposeId[] purposes = {KeyPurposeId.id_kp_serverAuth, KeyPurposeId.id_kp_clientAuth};
            x509v3CertificateBuilder
                    .addExtension(Extension.extendedKeyUsage, true, new ExtendedKeyUsage(purposes))
                    .addExtension(Extension.keyUsage, true,
                            new X509KeyUsage(X509KeyUsage.digitalSignature | X509KeyUsage.keyEncipherment))
                    .addExtension(Extension.basicConstraints, true, new BasicConstraints(false));
        }
        return x509v3CertificateBuilder;
    }

    /**
     * Adds the non-critical {@link #CERT_VERSION_OID} extension. Its value is a GeneralNames containing a
     * single otherName entry whose value is {@code SEQUENCE { CERT_VERSION_OID, INTEGER version }}.
     */
    private void addCertificateVersion(X509v3CertificateBuilder builder, int version)
            throws CertIOException {
        DERSequence versionValue = new DERSequence(new ASN1Encodable[]{CERT_VERSION_OID, new ASN1Integer(version)});
        GeneralNames versionName = new GeneralNames(new GeneralName(GeneralName.otherName, versionValue));
        builder.addExtension(CERT_VERSION_OID, false, versionName);
    }

    /**
     * Collects the subject alternative names for this build: the configured ones, plus the subject's name
     * when {@link #useSubForSAN()} was requested. Works on a copy, so repeated builds stay identical.
     */
    private GeneralName[] getSANContent() {
        List<GeneralName> names = new ArrayList<>(sanValues);
        if (useSubForSAN) {
            names.add(new GeneralName(GeneralName.dNSName, getSubjectDnsName()));
        }
        return names.toArray(new GeneralName[0]);
    }

    /**
     * Returns the subject's CN value, or the first RDN value if the subject has no CN.
     * RFC 2253 string form lists RDNs in reverse of their DER order. So for "CN=host, O=org" the first
     * encoded RDN is O, which is why the CN is looked up explicitly.
     */
    private String getSubjectDnsName() {
        X500Name x500Name = X500Name.getInstance(sub.getEncoded());
        RDN[] candidates = x500Name.getRDNs(BCStyle.CN);
        if (candidates.length == 0) {
            candidates = x500Name.getRDNs();
        }
        if (candidates.length == 0) {
            throw new IllegalStateException("Cannot derive a subject alternative name from an empty subject");
        }
        return IETFUtils.valueToString(candidates[0].getFirst().getValue());
    }

    /** Signs the certificate with {@code SHA256WithRSA} via "BC" and parses it with the standard X.509 factory. */
    private X509Certificate getSignedCertificate(PrivateKey issuerPrivateKey, X509v3CertificateBuilder builder)
            throws OperatorCreationException, CertificateException, IOException {
        ContentSigner sigGen = new JcaContentSignerBuilder(SIG_ALG).setProvider(BC_PROVIDER).build(issuerPrivateKey);
        X509CertificateHolder certHolder = builder.build(sigGen);

        try (final ByteArrayInputStream is = new ByteArrayInputStream(certHolder.getEncoded())) {
            return (X509Certificate) CertificateFactory.getInstance("X.509").generateCertificate(is);
        }
    }
}
