/*
 * Decompiled with CFR 0.152.
 */
package se.swedenconnect.opensaml.xmlsec.encryption.support;

import com.google.common.primitives.Bytes;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.ECPublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.List;
import javax.crypto.KeyAgreement;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import net.shibboleth.utilities.java.support.codec.Base64Support;
import net.shibboleth.utilities.java.support.logic.Constraint;
import org.bouncycastle.asn1.ASN1Encodable;
import org.bouncycastle.asn1.ASN1EncodableVector;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.ASN1Sequence;
import org.bouncycastle.asn1.ASN1StreamParser;
import org.bouncycastle.asn1.DERBitString;
import org.bouncycastle.asn1.DEROutputStream;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.crypto.DerivationParameters;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.agreement.kdf.ConcatenationKDFGenerator;
import org.bouncycastle.crypto.digests.RIPEMD160Digest;
import org.bouncycastle.crypto.digests.SHA1Digest;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.digests.SHA384Digest;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.params.KDFParameters;
import org.bouncycastle.jce.spec.ECNamedCurveGenParameterSpec;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.Credential;
import org.opensaml.xmlsec.algorithm.AlgorithmSupport;
import org.opensaml.xmlsec.encryption.AgreementMethod;
import org.opensaml.xmlsec.encryption.OriginatorKeyInfo;
import org.opensaml.xmlsec.signature.DEREncodedKeyValue;
import org.opensaml.xmlsec.signature.ECKeyValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import se.swedenconnect.opensaml.security.credential.KeyAgreementCredential;
import se.swedenconnect.opensaml.xmlsec.algorithm.descriptors.NamedCurve;
import se.swedenconnect.opensaml.xmlsec.algorithm.descriptors.NamedCurveRegistry;
import se.swedenconnect.opensaml.xmlsec.encryption.ConcatKDFParams;
import se.swedenconnect.opensaml.xmlsec.encryption.KeyDerivationMethod;

public class ECDHSupport {
    private static final Logger log = LoggerFactory.getLogger(ECDHSupport.class);
    public static final String EC_PUBLIC_KEY_OID = "1.2.840.10045.2.1";

    public static KeyAgreementCredential createKeyAgreementCredential(Credential peerCredential, String keyWrappingAlgorithm, KeyDerivationMethod keyDerivationMethod) throws SecurityException {
        Constraint.isNotNull((Object)peerCredential, (String)"peerCredential must not be null");
        Constraint.isNotNull((Object)peerCredential.getPublicKey(), (String)"peerCredential must contain a public key");
        Constraint.isTrue((boolean)ECPublicKey.class.isInstance(peerCredential.getPublicKey()), (String)"Public key of peerCredential must be an ECPublicKey");
        Constraint.isNotNull((Object)keyWrappingAlgorithm, (String)"keyWrappingAlgorithm must not be null");
        Constraint.isNotNull((Object)keyDerivationMethod, (String)"keyDerivationMethod must not be null");
        Constraint.isTrue((boolean)"http://www.w3.org/2009/xmlenc11#ConcatKDF".equals(keyDerivationMethod.getAlgorithm()), (String)String.format("{} key derivation method is not supported - {} is required", keyDerivationMethod.getAlgorithm(), "http://www.w3.org/2009/xmlenc11#ConcatKDF"));
        ConcatKDFParams concatKDFParams = keyDerivationMethod.getUnknownXMLObjects(ConcatKDFParams.DEFAULT_ELEMENT_NAME).stream().map(ConcatKDFParams.class::cast).findFirst().orElse(null);
        Constraint.isNotNull((Object)concatKDFParams, (String)"ConcatKDF params is missing from KeyDerivationMethod");
        NamedCurve namedCurve = ECDHSupport.getNamedCurve((ECPublicKey)peerCredential.getPublicKey());
        if (namedCurve == null) {
            throw new SecurityException("Unsupported named curve in EC public key");
        }
        KeyPair generatedKeyPair = null;
        try {
            log.debug("Generating EC key pair for named curve {} ...", (Object)namedCurve.getName());
            ECNamedCurveGenParameterSpec parameterSpec = new ECNamedCurveGenParameterSpec(namedCurve.getName());
            KeyPairGenerator kpg = KeyPairGenerator.getInstance("EC", "BC");
            kpg.initialize((AlgorithmParameterSpec)parameterSpec);
            generatedKeyPair = kpg.generateKeyPair();
        }
        catch (InvalidAlgorithmParameterException | NoSuchAlgorithmException | NoSuchProviderException e) {
            String msg = String.format("Failed to generate an EC key pair for curve %s - %s", namedCurve.getName(), e.getMessage());
            log.error("{}", (Object)msg, (Object)e);
            throw new SecurityException(msg, (Exception)e);
        }
        SecretKey keyAgreementKey = null;
        try {
            log.debug("Generating shared secret for ECDH key agreement ...");
            KeyAgreement ka = KeyAgreement.getInstance("ECDH", "BC");
            ka.init(generatedKeyPair.getPrivate());
            ka.doPhase(peerCredential.getPublicKey(), true);
            byte[] sharedSecret = ka.generateSecret();
            String keyWrappingJcaAlgorithmId = AlgorithmSupport.getAlgorithmID((String)keyWrappingAlgorithm);
            if (keyWrappingAlgorithm == null) {
                String msg = String.format("Algorithm %s is not supported", keyWrappingAlgorithm);
                log.error(msg);
                throw new SecurityException(msg);
            }
            Integer keyWrappingKeySize = AlgorithmSupport.getKeyLength((String)keyWrappingAlgorithm);
            if (keyWrappingKeySize == null) {
                String msg = String.format("Unknown key size for algorithm %s - cannot proceed", keyWrappingAlgorithm);
                log.error(msg);
                throw new SecurityException(msg);
            }
            log.debug("Generating key agreement key ...");
            keyAgreementKey = ECDHSupport.generateKeyAgreementKey(sharedSecret, concatKDFParams, keyWrappingJcaAlgorithmId, keyWrappingKeySize);
        }
        catch (InvalidKeyException | NoSuchAlgorithmException | NoSuchProviderException e) {
            String msg = "Failed to generate shared secret for ECDH key agreement";
            log.error("{}", (Object)msg, (Object)e);
            throw new SecurityException(msg, (Exception)e);
        }
        return new KeyAgreementCredential(keyAgreementKey, generatedKeyPair.getPublic(), peerCredential, "http://www.w3.org/2009/xmlenc11#ECDH-ES", keyDerivationMethod);
    }

    public static SecretKey getKeyAgreementKey(PrivateKey decrypterKey, AgreementMethod agreementMethod, String keyWrappingJcaAlgorithmId, int keyWrappingKeySize) throws SecurityException {
        Constraint.isNotNull((Object)decrypterKey, (String)"decrypterKey must not be null");
        Constraint.isNotNull((Object)decrypterKey, (String)"keyWrappingAlgorithm must not be null");
        Constraint.isNotNull((Object)agreementMethod, (String)"agreementMethod must not be null");
        try {
            if (!"http://www.w3.org/2009/xmlenc11#ECDH-ES".equals(agreementMethod.getAlgorithm())) {
                throw new SecurityException("Unsupported agreement method algorithm - " + agreementMethod.getAlgorithm());
            }
            List kdms = agreementMethod.getUnknownXMLObjects(KeyDerivationMethod.DEFAULT_ELEMENT_NAME);
            if (kdms.isEmpty()) {
                throw new SecurityException("No KeyDerivationMethod element found under supplied AgreementMethod");
            }
            KeyDerivationMethod keyDerivationMethod = (KeyDerivationMethod)KeyDerivationMethod.class.cast(kdms.get(0));
            if (!"http://www.w3.org/2009/xmlenc11#ConcatKDF".equals(keyDerivationMethod.getAlgorithm())) {
                throw new SecurityException("Unsupported key derivation method - " + keyDerivationMethod.getAlgorithm());
            }
            List pars = keyDerivationMethod.getUnknownXMLObjects(ConcatKDFParams.DEFAULT_ELEMENT_NAME);
            if (pars.isEmpty()) {
                throw new SecurityException("Missing ConcatKDFParams under KeyDerivation element");
            }
            ConcatKDFParams concatKDFParams = (ConcatKDFParams)ConcatKDFParams.class.cast(pars.get(0));
            if (agreementMethod.getOriginatorKeyInfo() == null) {
                throw new SecurityException("Missing OriginatorKeyInfo - need generated public key");
            }
            OriginatorKeyInfo originatorKeyInfo = agreementMethod.getOriginatorKeyInfo();
            byte[] encodedPublicKey = null;
            if (!originatorKeyInfo.getKeyValues().isEmpty()) {
                ECKeyValue ecKeyValue = originatorKeyInfo.getKeyValues().stream().filter(v -> v.getECKeyValue() != null).map(v -> v.getECKeyValue()).findFirst().orElse(null);
                if (ecKeyValue != null) {
                    encodedPublicKey = ECDHSupport.getPublicKeyBytes(Base64Support.decode((String)ecKeyValue.getPublicKey().getValue()), ecKeyValue.getNamedCurve().getURI());
                }
            } else if (!originatorKeyInfo.getDEREncodedKeyValues().isEmpty()) {
                encodedPublicKey = Base64Support.decode((String)((DEREncodedKeyValue)originatorKeyInfo.getDEREncodedKeyValues().get(0)).getValue());
            }
            if (encodedPublicKey == null) {
                throw new SecurityException("Could not find generated public key in OriginatorKeyInfo");
            }
            KeyFactory keyFactory = KeyFactory.getInstance("EC", "BC");
            X509EncodedKeySpec x509EncodedKeySpec = new X509EncodedKeySpec(encodedPublicKey);
            PublicKey publicKey = keyFactory.generatePublic(x509EncodedKeySpec);
            KeyAgreement ka = KeyAgreement.getInstance("ECDH", "BC");
            ka.init(decrypterKey);
            ka.doPhase(publicKey, true);
            byte[] sharedSecret = ka.generateSecret();
            return ECDHSupport.generateKeyAgreementKey(sharedSecret, concatKDFParams, keyWrappingJcaAlgorithmId, keyWrappingKeySize);
        }
        catch (NoSuchAlgorithmException | NoSuchProviderException | InvalidKeySpecException e) {
            throw new SecurityException("Failed to generate key - " + e.getMessage(), (Exception)e);
        }
        catch (InvalidKeyException e) {
            throw new SecurityException("Failed to generate shared secret", (Exception)e);
        }
    }

    private static SecretKey generateKeyAgreementKey(byte[] sharedSecret, ConcatKDFParams concatKDFParams, String keyWrappingJcaAlgorithmId, int keyWrappingKeySize) throws SecurityException {
        if (concatKDFParams.getDigestMethod() == null || concatKDFParams.getDigestMethod().getAlgorithm() == null) {
            throw new SecurityException("Missing digest method in ConcatKDFParams");
        }
        if (concatKDFParams.getAlgorithmID() == null) {
            throw new SecurityException("Missing AlgorithmID attribute from ConcatKDFParams");
        }
        if (concatKDFParams.getPartyUInfo() == null) {
            throw new SecurityException("Missing PartyUInfo attribute from ConcatKDFParams");
        }
        if (concatKDFParams.getPartyVInfo() == null) {
            throw new SecurityException("Missing PartyVInfo attribute from ConcatKDFParams");
        }
        byte[] combinedConcatParams = Bytes.concat((byte[][])new byte[][]{ECDHSupport.extractConcatKDFParamVal(concatKDFParams.getAlgorithmID()), ECDHSupport.extractConcatKDFParamVal(concatKDFParams.getPartyUInfo()), ECDHSupport.extractConcatKDFParamVal(concatKDFParams.getPartyVInfo())});
        if (concatKDFParams.getSuppPubInfo() != null) {
            combinedConcatParams = Bytes.concat((byte[][])new byte[][]{combinedConcatParams, ECDHSupport.extractConcatKDFParamVal(concatKDFParams.getSuppPubInfo())});
        }
        if (concatKDFParams.getSuppPrivInfo() != null) {
            combinedConcatParams = Bytes.concat((byte[][])new byte[][]{combinedConcatParams, ECDHSupport.extractConcatKDFParamVal(concatKDFParams.getSuppPrivInfo())});
        }
        SHA256Digest digest = null;
        if (concatKDFParams.getDigestMethod().getAlgorithm().equals("http://www.w3.org/2001/04/xmlenc#sha256")) {
            digest = new SHA256Digest();
        } else if (concatKDFParams.getDigestMethod().getAlgorithm().equals("http://www.w3.org/2001/04/xmlenc#sha512")) {
            digest = new SHA512Digest();
        } else if (concatKDFParams.getDigestMethod().getAlgorithm().equals("http://www.w3.org/2000/09/xmldsig#sha1")) {
            digest = new SHA1Digest();
        } else if (concatKDFParams.getDigestMethod().getAlgorithm().equals("http://www.w3.org/2001/04/xmldsig-more#sha384")) {
            digest = new SHA384Digest();
        } else if (concatKDFParams.getDigestMethod().getAlgorithm().equals("http://www.w3.org/2001/04/xmlenc#ripemd160")) {
            digest = new RIPEMD160Digest();
        } else {
            throw new SecurityException("ConcatKDFParams contains unsupported digest algorithm - " + concatKDFParams.getDigestMethod().getAlgorithm());
        }
        ConcatenationKDFGenerator concatKDF = new ConcatenationKDFGenerator((Digest)digest);
        KDFParameters kdfParams = new KDFParameters(sharedSecret, combinedConcatParams);
        concatKDF.init((DerivationParameters)kdfParams);
        int keyLength = keyWrappingKeySize / 8;
        byte[] rawKey = new byte[keyLength];
        concatKDF.generateBytes(rawKey, 0, keyLength);
        return new SecretKeySpec(rawKey, keyWrappingJcaAlgorithmId);
    }

    private static byte[] extractConcatKDFParamVal(byte[] paddedParam) throws SecurityException {
        if (paddedParam == null) {
            return new byte[0];
        }
        if (paddedParam.length == 0) {
            return new byte[0];
        }
        if (paddedParam[0] == 8 && paddedParam.length > 1) {
            return Arrays.copyOfRange(paddedParam, 2, paddedParam.length);
        }
        if (paddedParam[0] != 0) {
            throw new IllegalArgumentException("Unsupported use of padding bits in ConcatKDF parameters");
        }
        return Arrays.copyOfRange(paddedParam, 1, paddedParam.length);
    }

    private static byte[] getPublicKeyBytes(byte[] publicKeyBytes, String curveOidUri) throws SecurityException {
        ASN1EncodableVector publicKeyParamSeq = new ASN1EncodableVector();
        publicKeyParamSeq.add((ASN1Encodable)new ASN1ObjectIdentifier(EC_PUBLIC_KEY_OID));
        String oid = curveOidUri.startsWith("urn:oid:") ? curveOidUri.substring(8) : curveOidUri;
        publicKeyParamSeq.add((ASN1Encodable)new ASN1ObjectIdentifier(oid));
        ASN1EncodableVector publicKeySeq = new ASN1EncodableVector();
        publicKeySeq.add((ASN1Encodable)new DERSequence(publicKeyParamSeq));
        publicKeySeq.add((ASN1Encodable)new DERBitString(publicKeyBytes));
        ByteArrayOutputStream bout = new ByteArrayOutputStream();
        DEROutputStream dout = new DEROutputStream((OutputStream)bout);
        try {
            dout.writeObject((ASN1Primitive)new DERSequence(publicKeySeq));
            byte[] byArray = bout.toByteArray();
            return byArray;
        }
        catch (IOException e) {
            throw new SecurityException("Failed to get EC public key bytes", (Exception)e);
        }
        finally {
            try {
                dout.close();
                bout.close();
            }
            catch (IOException iOException) {}
        }
    }

    public static NamedCurve getNamedCurve(ECPublicKey publicKey) {
        try {
            ASN1StreamParser parser = new ASN1StreamParser(publicKey.getEncoded());
            ASN1Sequence seq = (ASN1Sequence)parser.readObject().toASN1Primitive();
            ASN1Sequence innerSeq = (ASN1Sequence)seq.getObjectAt(0).toASN1Primitive();
            ASN1ObjectIdentifier ecPubKeyoid = (ASN1ObjectIdentifier)innerSeq.getObjectAt(0).toASN1Primitive();
            if (!ecPubKeyoid.getId().equals(EC_PUBLIC_KEY_OID)) {
                log.error("The provided public key with key type OID {} is not a valid EC public key", (Object)ecPubKeyoid.getId());
                return null;
            }
            ASN1ObjectIdentifier oid = (ASN1ObjectIdentifier)innerSeq.getObjectAt(1).toASN1Primitive();
            log.debug("Asking NamedCurveRegistry for curve having OID {} ...", (Object)oid);
            NamedCurveRegistry registry = (NamedCurveRegistry)ConfigurationService.get(NamedCurveRegistry.class);
            if (registry == null) {
                throw new RuntimeException("NamedCurveRegistry is not available");
            }
            NamedCurve curve = registry.get(oid.getId());
            if (curve != null) {
                log.debug("Looked up NamedCurve {} ({}) (keyLength:{})", new Object[]{curve.getObjectIdentifier(), curve.getName(), curve.getKeyLength()});
                return curve;
            }
            log.debug("NamedCurve with OID {} was not found in the NamedCurveRegistry", (Object)oid.getId());
            return null;
        }
        catch (IOException | NullPointerException e) {
            log.error("Unable to parse the provided public key as an EC public key based on a named EC curve", (Throwable)e);
            return null;
        }
    }
}

