/*
 * Decompiled with CFR 0.152.
 */
package io.fusionauth.jwks;

import io.fusionauth.der.DerDecodingException;
import io.fusionauth.der.DerInputStream;
import io.fusionauth.der.DerValue;
import io.fusionauth.jwks.JSONWebKeyBuilderException;
import io.fusionauth.jwks.JWKUtils;
import io.fusionauth.jwks.domain.JSONWebKey;
import io.fusionauth.jwt.JWTUtils;
import io.fusionauth.jwt.domain.Algorithm;
import io.fusionauth.jwt.domain.KeyType;
import io.fusionauth.pem.domain.PEM;
import io.fusionauth.security.KeyUtils;
import java.io.IOException;
import java.math.BigInteger;
import java.security.Key;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.security.interfaces.ECKey;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.EdECPrivateKey;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.ECFieldFp;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.EllipticCurve;
import java.util.Base64;
import java.util.Collections;
import java.util.Objects;

/**
 * Builds {@link JSONWebKey} objects from PEM encoded strings, private keys, public keys and certificates.
 * Every key produced by this builder has {@code use} set to {@code "sig"}.
 */
public class JSONWebKeyBuilder {
    /**
     * Builds a JSON Web Key from a PEM encoded string.
     *
     * <p>A private key in the PEM takes precedence, then a certificate, then a public key.</p>
     *
     * @param encodedPEM the PEM encoded value; must not be null
     * @return the JSON Web Key
     * @throws JSONWebKeyBuilderException if the PEM contains no private key, certificate or public key
     */
    public JSONWebKey build(String encodedPEM) {
        Objects.requireNonNull(encodedPEM, "encodedPEM");
        PEM pem = PEM.decode(encodedPEM);
        if (pem.privateKey != null) {
            return this.build(pem.privateKey);
        }
        if (pem.certificate != null) {
            return this.build(pem.certificate);
        }
        if (pem.publicKey != null) {
            return this.build(pem.publicKey);
        }
        throw new JSONWebKeyBuilderException("The provided PEM did not contain a public or private key.");
    }

    /**
     * Builds a JSON Web Key from a private key.
     *
     * <p>RSA keys populate {@code n} and {@code d}; RSA CRT keys additionally populate {@code e}, {@code p},
     * {@code q}, {@code dp}, {@code dq} and {@code qi}. EC keys populate {@code crv}, {@code d} and the public
     * point coordinates {@code x}/{@code y}, which are derived from the private scalar. EdDSA keys populate
     * {@code crv}, {@code d} and the derived public key {@code x}.</p>
     *
     * @param privateKey the private key; must not be null
     * @return the JSON Web Key
     * @throws JSONWebKeyBuilderException if the curve cannot be read or the public key cannot be derived
     */
    public JSONWebKey build(PrivateKey privateKey) {
        Objects.requireNonNull(privateKey, "privateKey");
        JSONWebKey key = new JSONWebKey();
        key.kty = this.getKeyType(privateKey);
        key.use = "sig";
        if (privateKey instanceof RSAPrivateKey) {
            RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) privateKey;
            key.n = JWKUtils.base64EncodeUint(rsaPrivateKey.getModulus());
            key.d = JWKUtils.base64EncodeUint(rsaPrivateKey.getPrivateExponent());
        }
        // RSAPrivateCrtKey extends RSAPrivateKey, so a CRT key receives the fields above plus the CRT parameters.
        if (privateKey instanceof RSAPrivateCrtKey) {
            RSAPrivateCrtKey crtKey = (RSAPrivateCrtKey) privateKey;
            key.e = JWKUtils.base64EncodeUint(crtKey.getPublicExponent());
            key.p = JWKUtils.base64EncodeUint(crtKey.getPrimeP());
            key.q = JWKUtils.base64EncodeUint(crtKey.getPrimeQ());
            key.dp = JWKUtils.base64EncodeUint(crtKey.getPrimeExponentP());
            key.dq = JWKUtils.base64EncodeUint(crtKey.getPrimeExponentQ());
            key.qi = JWKUtils.base64EncodeUint(crtKey.getCrtCoefficient());
        }
        if (privateKey instanceof ECPrivateKey) {
            ECPrivateKey ecPrivateKey = (ECPrivateKey) privateKey;
            key.crv = this.getCurveName(ecPrivateKey);
            key.alg = ecAlgorithmForCurve(key.crv);
            int byteLength = this.getCoordinateLength(ecPrivateKey);
            key.d = JWKUtils.base64EncodeUint(ecPrivateKey.getS(), byteLength);
            // x/y must describe the PUBLIC point Q = s*G, not the curve generator G.
            ECPoint publicPoint = derivePublicPoint(ecPrivateKey);
            key.x = JWKUtils.base64EncodeUint(publicPoint.getAffineX(), byteLength);
            key.y = JWKUtils.base64EncodeUint(publicPoint.getAffineY(), byteLength);
        } else if (privateKey instanceof EdECPrivateKey) {
            EdECPrivateKey edPrivateKey = (EdECPrivateKey) privateKey;
            key.crv = this.getCurveName(edPrivateKey);
            key.alg = Algorithm.fromName(key.crv);
            byte[] privateKeyBytes = edPrivateKey.getBytes()
                .orElseThrow(() -> new JSONWebKeyBuilderException("Unable to obtain the private key bytes."));
            key.d = Base64.getUrlEncoder().withoutPadding().encodeToString(privateKeyBytes);
            try {
                byte[] publicKeyBytes = KeyUtils.deriveEdDSAPublicKeyFromPrivate(privateKeyBytes, key.crv);
                key.x = Base64.getUrlEncoder().withoutPadding().encodeToString(publicKeyBytes);
            } catch (Exception e) {
                throw new JSONWebKeyBuilderException("Unable to build the public key for the EdDSA private key using curve [" + key.crv + "]", e);
            }
        }
        return key;
    }

    /**
     * Builds a JSON Web Key from a public key.
     *
     * <p>RSA keys populate {@code e} and {@code n}; EC keys populate {@code crv}, {@code x} and {@code y};
     * OKP (EdDSA) keys populate {@code crv} and {@code x}.</p>
     *
     * @param publicKey the public key; must not be null
     * @return the JSON Web Key
     * @throws JSONWebKeyBuilderException if the curve or the DER encoded key cannot be read
     */
    public JSONWebKey build(PublicKey publicKey) {
        Objects.requireNonNull(publicKey, "publicKey");
        JSONWebKey key = new JSONWebKey();
        key.kty = this.getKeyType(publicKey);
        key.use = "sig";
        if (publicKey instanceof RSAPublicKey) {
            RSAPublicKey rsaPublicKey = (RSAPublicKey) publicKey;
            key.e = JWKUtils.base64EncodeUint(rsaPublicKey.getPublicExponent());
            key.n = JWKUtils.base64EncodeUint(rsaPublicKey.getModulus());
        } else if (key.kty == KeyType.EC) {
            ECPublicKey ecPublicKey = (ECPublicKey) publicKey;
            key.crv = this.getCurveName(ecPublicKey);
            // Select the algorithm by curve name (same rule as the private-key path) rather than by key length,
            // so a non-NIST curve of the same size is not mislabelled.
            key.alg = ecAlgorithmForCurve(key.crv);
            int byteLength = this.getCoordinateLength(ecPublicKey);
            key.x = JWKUtils.base64EncodeUint(ecPublicKey.getW().getAffineX(), byteLength);
            key.y = JWKUtils.base64EncodeUint(ecPublicKey.getW().getAffineY(), byteLength);
        } else if (key.kty == KeyType.OKP) {
            key.crv = this.getCurveName(publicKey);
            key.alg = Algorithm.fromName(key.crv);
            int keyLength = KeyUtils.getKeyLength(publicKey);
            byte[] publicKeyBytes;
            try {
                DerValue[] sequence = new DerInputStream(publicKey.getEncoded()).getSequence();
                publicKeyBytes = sequence[1].toByteArray();
            } catch (DerDecodingException e) {
                throw new JSONWebKeyBuilderException("Unable to read the public key from the DER encoded key.", e);
            }
            // Interpret the bytes as an unsigned magnitude; a leading byte >= 0x80 must not yield a negative value.
            key.x = JWKUtils.base64EncodeUint(new BigInteger(1, publicKeyBytes), keyLength);
        }
        return key;
    }

    /**
     * Builds a JSON Web Key from a certificate's public key.
     *
     * <p>For X.509 certificates, {@code x5c}, {@code x5t} and {@code x5t#S256} are also populated, and when the
     * public key did not determine {@code alg}, it is derived from the certificate's signature algorithm.</p>
     *
     * @param certificate the certificate; must not be null
     * @return the JSON Web Key
     * @throws JSONWebKeyBuilderException if the certificate cannot be encoded
     */
    public JSONWebKey build(Certificate certificate) {
        Objects.requireNonNull(certificate, "certificate");
        JSONWebKey key = this.build(certificate.getPublicKey());
        if (!(certificate instanceof X509Certificate)) {
            return key;
        }
        X509Certificate x509Certificate = (X509Certificate) certificate;
        if (key.alg == null) {
            key.alg = this.determineKeyAlgorithm(x509Certificate);
        }
        try {
            // encodeToString avoids decoding the Base64 bytes with the platform default charset.
            String encodedCertificate = Base64.getEncoder().encodeToString(certificate.getEncoded());
            key.x5c = Collections.singletonList(encodedCertificate);
            key.x5t = JWTUtils.generateJWS_x5t(encodedCertificate);
            key.x5t_256 = JWTUtils.generateJWS_x5t("SHA-256", encodedCertificate);
        } catch (CertificateEncodingException e) {
            throw new JSONWebKeyBuilderException("Failed to decode X.509 certificate.", e);
        }
        return key;
    }

    /**
     * Returns the curve name reported by {@link KeyUtils#getCurveName} for the given key.
     *
     * @throws JSONWebKeyBuilderException wrapping any failure to read the curve
     */
    private String getCurveName(Key key) {
        try {
            return KeyUtils.getCurveName(key);
        } catch (Exception e) {
            throw new JSONWebKeyBuilderException("Unable to read the Object Identifier of the public key.", e);
        }
    }

    /**
     * Maps a JWK EC curve name to its ECDSA algorithm, or returns null when the curve is null or unrecognized.
     */
    private static Algorithm ecAlgorithmForCurve(String curve) {
        if (curve == null) {
            return null;
        }
        return switch (curve) {
            case "P-256" -> Algorithm.ES256;
            case "P-384" -> Algorithm.ES384;
            case "P-521" -> Algorithm.ES512;
            default -> null;
        };
    }

    /**
     * Returns the number of bytes needed to hold one field element (coordinate) of the key's curve.
     */
    private int getCoordinateLength(ECKey key) {
        int fieldSizeBits = key.getParams().getCurve().getField().getFieldSize();
        return (fieldSizeBits + 7) / 8;
    }

    /**
     * Derives the public point {@code Q = s * G} of an EC private key defined over a prime field.
     *
     * <p>{@link ECPrivateKey} exposes only the private scalar and the domain parameters, so the public
     * coordinates must be recomputed. This uses affine double-and-add with {@link BigInteger} arithmetic. It is
     * not constant-time; it runs once per call to {@link #build(PrivateKey)}.</p>
     *
     * @throws JSONWebKeyBuilderException if the curve is not over a prime field or the scalar is zero mod the order
     */
    private static ECPoint derivePublicPoint(ECPrivateKey privateKey) {
        ECParameterSpec params = privateKey.getParams();
        EllipticCurve curve = params.getCurve();
        if (!(curve.getField() instanceof ECFieldFp)) {
            throw new JSONWebKeyBuilderException("Unable to derive the public key for an EC private key that is not defined over a prime field.");
        }
        BigInteger p = ((ECFieldFp) curve.getField()).getP();
        BigInteger a = curve.getA();
        BigInteger scalar = privateKey.getS().mod(params.getOrder());
        ECPoint result = ECPoint.POINT_INFINITY;
        ECPoint addend = params.getGenerator();
        for (int bit = 0; bit < scalar.bitLength(); bit++) {
            if (scalar.testBit(bit)) {
                result = addPoints(result, addend, p, a);
            }
            addend = addPoints(addend, addend, p, a);
        }
        if (ECPoint.POINT_INFINITY.equals(result)) {
            throw new JSONWebKeyBuilderException("Unable to derive the public key for the EC private key; the private scalar is invalid.");
        }
        return result;
    }

    /**
     * Adds two points on the short Weierstrass curve {@code y^2 = x^3 + a*x + b} over GF(p), handling the point
     * at infinity, doubling, and the {@code P + (-P)} case.
     */
    private static ECPoint addPoints(ECPoint left, ECPoint right, BigInteger p, BigInteger a) {
        if (ECPoint.POINT_INFINITY.equals(left)) {
            return right;
        }
        if (ECPoint.POINT_INFINITY.equals(right)) {
            return left;
        }
        BigInteger x1 = left.getAffineX();
        BigInteger y1 = left.getAffineY();
        BigInteger x2 = right.getAffineX();
        BigInteger y2 = right.getAffineY();
        BigInteger lambda;
        if (x1.equals(x2)) {
            if (y1.add(y2).mod(p).signum() == 0) {
                // P + (-P), including doubling a point whose y coordinate is zero.
                return ECPoint.POINT_INFINITY;
            }
            // Tangent slope for doubling: (3*x^2 + a) / (2*y).
            BigInteger numerator = x1.multiply(x1).multiply(BigInteger.valueOf(3)).add(a);
            lambda = numerator.multiply(y1.shiftLeft(1).mod(p).modInverse(p)).mod(p);
        } else {
            // Chord slope: (y2 - y1) / (x2 - x1).
            lambda = y2.subtract(y1).multiply(x2.subtract(x1).mod(p).modInverse(p)).mod(p);
        }
        BigInteger x3 = lambda.multiply(lambda).subtract(x1).subtract(x2).mod(p);
        BigInteger y3 = lambda.multiply(x1.subtract(x3)).subtract(y1).mod(p);
        return new ECPoint(x3, y3);
    }

    /**
     * Determines the JWS algorithm from an X.509 certificate's signature algorithm.
     *
     * <p>Returns {@code Algorithm.fromName(sigAlgName)} when that resolves. For {@code RSASSA-PSS} the hash OID
     * is read from the encoded signature algorithm parameters and mapped to PS256/PS384/PS512. Returns null
     * otherwise, including when a PSS certificate has no parameters.</p>
     *
     * @throws JSONWebKeyBuilderException if the PSS parameters cannot be decoded
     */
    private Algorithm determineKeyAlgorithm(X509Certificate x509Certificate) {
        String sigAlgName = x509Certificate.getSigAlgName();
        Algorithm result = Algorithm.fromName(sigAlgName);
        if (result != null || !"RSASSA-PSS".equals(sigAlgName)) {
            return result;
        }
        byte[] encodedParams = x509Certificate.getSigAlgParams();
        if (encodedParams == null) {
            // No parameters to inspect, so the PSS hash algorithm cannot be determined.
            return null;
        }
        try {
            // NOTE(review): this traversal (second element, then its second element's OID) is preserved from the
            // original; it assumes the hash OID sits at that position in the parameters — confirm against samples.
            DerValue outerElement = new DerInputStream(encodedParams).getSequence()[1];
            String oid = new DerInputStream(outerElement.toByteArray()).getSequence()[1].getOID().toString();
            return switch (oid) {
                case "2.16.840.1.101.3.4.2.1" -> Algorithm.PS256;
                case "2.16.840.1.101.3.4.2.2" -> Algorithm.PS384;
                case "2.16.840.1.101.3.4.2.3" -> Algorithm.PS512;
                default -> null;
            };
        } catch (IOException e) {
            throw new JSONWebKeyBuilderException("Failed to decode X.509 certificate signature algorithm parameters to determine the key type.", e);
        }
    }

    /**
     * Maps a JCA key algorithm name to a JWK key type, or returns null when the name is not recognized.
     */
    private KeyType getKeyType(Key key) {
        return switch (key.getAlgorithm()) {
            case "RSA", "RSASSA-PSS" -> KeyType.RSA;
            case "EC" -> KeyType.EC;
            case "EdDSA", "Ed25519", "Ed448" -> KeyType.OKP;
            default -> null;
        };
    }
}

