package se.swedenconnect.opensaml.xmlsec.encryption.support;

import com.google.common.primitives.Bytes;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.ECPublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
import javax.crypto.KeyAgreement;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import net.shibboleth.utilities.java.support.codec.Base64Support;
import net.shibboleth.utilities.java.support.logic.Constraint;
import org.bouncycastle.asn1.ASN1EncodableVector;
import org.bouncycastle.asn1.ASN1Encoding;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.ASN1Sequence;
import org.bouncycastle.asn1.ASN1StreamParser;
import org.bouncycastle.asn1.DERBitString;
import org.bouncycastle.asn1.DEROutputStream;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.agreement.kdf.ConcatenationKDFGenerator;
import org.bouncycastle.crypto.digests.RIPEMD160Digest;
import org.bouncycastle.crypto.digests.SHA1Digest;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.digests.SHA384Digest;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.params.KDFParameters;
import org.bouncycastle.jce.spec.ECNamedCurveGenParameterSpec;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.Credential;
import org.opensaml.xmlsec.algorithm.AlgorithmSupport;
import org.opensaml.xmlsec.encryption.AgreementMethod;
import org.opensaml.xmlsec.encryption.OriginatorKeyInfo;
import org.opensaml.xmlsec.signature.DEREncodedKeyValue;
import org.opensaml.xmlsec.signature.ECKeyValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import se.swedenconnect.opensaml.security.credential.KeyAgreementCredential;
import se.swedenconnect.opensaml.xmlsec.algorithm.descriptors.NamedCurve;
import se.swedenconnect.opensaml.xmlsec.algorithm.descriptors.NamedCurveRegistry;
import se.swedenconnect.opensaml.xmlsec.encryption.ConcatKDFParams;
import se.swedenconnect.opensaml.xmlsec.encryption.KeyDerivationMethod;

/* loaded from: input_file:se/swedenconnect/opensaml/xmlsec/encryption/support/ECDHSupport.class */
/**
 * Support methods for ECDH-ES key agreement where the key agreement key is derived from the ECDH shared secret
 * using ConcatKDF.
 * <p>
 * The encrypting side calls {@link #createKeyAgreementCredential(Credential, String, KeyDerivationMethod)}. It
 * generates an ephemeral EC key pair and derives a key agreement key towards a peer's EC public key.
 * </p>
 * <p>
 * The decrypting side calls {@link #getKeyAgreementKey(PrivateKey, AgreementMethod, String, int)}. It re-derives
 * the key from the originator public key carried in an {@link AgreementMethod}.
 * </p>
 */
public class ECDHSupport {

    /** Class logger. */
    private static final Logger log = LoggerFactory.getLogger(ECDHSupport.class);

    /** Key type OID that an encoded public key must carry in its AlgorithmIdentifier to be accepted as an EC key. */
    public static final String EC_PUBLIC_KEY_OID = "1.2.840.10045.2.1";

    /** Optional prefix in front of a curve OID in a named-curve URI; it is stripped before the OID is parsed. */
    private static final String OID_URN_PREFIX = "urn:oid:";

    /** JCA provider used for EC key pair generation, EC key factory and ECDH key agreement. */
    private static final String PROVIDER = "BC";

    /** Log/diagnostic message used whenever an EC public key cannot be interpreted as a named-curve key. */
    private static final String PARSE_FAILURE_MSG =
        "Unable to parse the provided public key as an EC public key based on a named EC curve";

    /**
     * Creates a key agreement credential for encrypting towards the supplied peer credential.
     * <p>
     * The method works in three steps:
     * </p>
     * <ol>
     * <li>An ephemeral EC key pair is generated on the named curve of the peer's public key.</li>
     * <li>An ECDH shared secret is computed from the ephemeral private key and the peer's public key.</li>
     * <li>The key agreement key is derived from that secret using the {@link ConcatKDFParams} of the supplied key
     * derivation method.</li>
     * </ol>
     *
     * @param peerCredential the peer credential, which must hold an {@link ECPublicKey}
     * @param keyWrappingAlgorithm URI of the key wrapping algorithm the derived key is intended for. Its JCA
     *          algorithm name and key length are resolved through {@link AlgorithmSupport}.
     * @param keyDerivationMethod the key derivation method. It must be ConcatKDF and contain {@link ConcatKDFParams}.
     * @return a key agreement credential holding the derived key and the ephemeral public key
     * @throws SecurityException if the curve or the key wrapping algorithm is not supported, or if key generation,
     *           key agreement or key derivation fails
     */
    public static KeyAgreementCredential createKeyAgreementCredential(final Credential peerCredential,
            final String keyWrappingAlgorithm, final KeyDerivationMethod keyDerivationMethod)
            throws SecurityException {
        Constraint.isNotNull(peerCredential, "peerCredential must not be null");
        Constraint.isNotNull(peerCredential.getPublicKey(), "peerCredential must contain a public key");
        Constraint.isTrue(peerCredential.getPublicKey() instanceof ECPublicKey,
            "Public key of peerCredential must be an ECPublicKey");
        Constraint.isNotNull(keyWrappingAlgorithm, "keyWrappingAlgorithm must not be null");
        Constraint.isNotNull(keyDerivationMethod, "keyDerivationMethod must not be null");
        // String.format needs %s placeholders; SLF4J-style {} placeholders would be emitted literally.
        Constraint.isTrue(
            EcEncryptionConstants.ALGO_ID_KEYDERIVATION_CONCAT.equals(keyDerivationMethod.getAlgorithm()),
            String.format("%s key derivation method is not supported - %s is required",
                keyDerivationMethod.getAlgorithm(), EcEncryptionConstants.ALGO_ID_KEYDERIVATION_CONCAT));

        final ConcatKDFParams concatKDFParams = getFirst(
            keyDerivationMethod.getUnknownXMLObjects(ConcatKDFParams.DEFAULT_ELEMENT_NAME), ConcatKDFParams.class);
        Constraint.isNotNull(concatKDFParams, "ConcatKDF params is missing from KeyDerivationMethod");

        final NamedCurve namedCurve = getNamedCurve((ECPublicKey) peerCredential.getPublicKey());
        if (namedCurve == null) {
            throw new SecurityException("Unsupported named curve in EC public key");
        }

        // Resolve the key wrapping algorithm up front, so no key material is generated for an unusable algorithm.
        final String algorithmID = AlgorithmSupport.getAlgorithmID(keyWrappingAlgorithm);
        if (algorithmID == null) {
            final String msg = String.format("Algorithm %s is not supported", keyWrappingAlgorithm);
            log.error(msg);
            throw new SecurityException(msg);
        }
        final Integer keyLength = AlgorithmSupport.getKeyLength(keyWrappingAlgorithm);
        if (keyLength == null) {
            final String msg =
                String.format("Unknown key size for algorithm %s - cannot proceed", keyWrappingAlgorithm);
            log.error(msg);
            throw new SecurityException(msg);
        }

        final KeyPair ephemeralKeyPair = generateKeyPair(namedCurve);

        log.debug("Generating shared secret for ECDH key agreement ...");
        final byte[] sharedSecret;
        try {
            sharedSecret = generateSharedSecret(ephemeralKeyPair.getPrivate(), peerCredential.getPublicKey());
        } catch (InvalidKeyException | NoSuchAlgorithmException | NoSuchProviderException e) {
            log.error("Failed to generate shared secret for ECDH key agreement", e);
            throw new SecurityException("Failed to generate shared secret for ECDH key agreement", e);
        }

        log.debug("Generating key agreement key ...");
        final SecretKey keyAgreementKey =
            generateKeyAgreementKey(sharedSecret, concatKDFParams, algorithmID, keyLength.intValue());
        return new KeyAgreementCredential(keyAgreementKey, ephemeralKeyPair.getPublic(), peerCredential,
            EcEncryptionConstants.ALGO_ID_KEYAGREEMENT_ECDH_ES, keyDerivationMethod);
    }

    /**
     * Derives the key agreement key on the decrypting side from the supplied agreement method.
     * <p>
     * The agreement method must satisfy all of the following:
     * </p>
     * <ul>
     * <li>It uses ECDH-ES.</li>
     * <li>It contains a ConcatKDF {@link KeyDerivationMethod} with {@link ConcatKDFParams}.</li>
     * <li>Its {@link OriginatorKeyInfo} carries the originator's EC public key, either as an {@link ECKeyValue} or
     * as a {@link DEREncodedKeyValue}.</li>
     * </ul>
     *
     * @param decrypterKey the EC private key of the decrypting party
     * @param agreementMethod the agreement method
     * @param keyWrappingAlgorithm the algorithm name given to the returned {@link SecretKey}. It is used as-is.
     *          NOTE(review): presumably a JCA algorithm name, as on the creating side; confirm with callers.
     * @param keyLength the length, in bits, of the key to derive
     * @return the derived key agreement key
     * @throws SecurityException if the agreement method is unsupported or incomplete, or if the key agreement or
     *           key derivation fails
     */
    public static SecretKey getKeyAgreementKey(final PrivateKey decrypterKey, final AgreementMethod agreementMethod,
            final String keyWrappingAlgorithm, final int keyLength) throws SecurityException {
        Constraint.isNotNull(decrypterKey, "decrypterKey must not be null");
        Constraint.isNotNull(keyWrappingAlgorithm, "keyWrappingAlgorithm must not be null");
        Constraint.isNotNull(agreementMethod, "agreementMethod must not be null");

        if (!EcEncryptionConstants.ALGO_ID_KEYAGREEMENT_ECDH_ES.equals(agreementMethod.getAlgorithm())) {
            throw new SecurityException("Unsupported agreement method algorithm - " + agreementMethod.getAlgorithm());
        }
        final KeyDerivationMethod keyDerivationMethod = getFirst(
            agreementMethod.getUnknownXMLObjects(KeyDerivationMethod.DEFAULT_ELEMENT_NAME), KeyDerivationMethod.class);
        if (keyDerivationMethod == null) {
            throw new SecurityException("No KeyDerivationMethod element found under supplied AgreementMethod");
        }
        if (!EcEncryptionConstants.ALGO_ID_KEYDERIVATION_CONCAT.equals(keyDerivationMethod.getAlgorithm())) {
            throw new SecurityException("Unsupported key derivation method - " + keyDerivationMethod.getAlgorithm());
        }
        final ConcatKDFParams concatKDFParams = getFirst(
            keyDerivationMethod.getUnknownXMLObjects(ConcatKDFParams.DEFAULT_ELEMENT_NAME), ConcatKDFParams.class);
        if (concatKDFParams == null) {
            throw new SecurityException("Missing ConcatKDFParams under KeyDerivation element");
        }
        final OriginatorKeyInfo originatorKeyInfo = agreementMethod.getOriginatorKeyInfo();
        if (originatorKeyInfo == null) {
            throw new SecurityException("Missing OriginatorKeyInfo - need generated public key");
        }
        final byte[] originatorKeyBytes = getOriginatorPublicKeyBytes(originatorKeyInfo);
        if (originatorKeyBytes == null) {
            throw new SecurityException("Could not find generated public key in OriginatorKeyInfo");
        }

        try {
            final PublicKey originatorKey =
                KeyFactory.getInstance("EC", PROVIDER).generatePublic(new X509EncodedKeySpec(originatorKeyBytes));
            final byte[] sharedSecret = generateSharedSecret(decrypterKey, originatorKey);
            return generateKeyAgreementKey(sharedSecret, concatKDFParams, keyWrappingAlgorithm, keyLength);
        } catch (InvalidKeyException e) {
            throw new SecurityException("Failed to generate shared secret", e);
        } catch (NoSuchAlgorithmException | NoSuchProviderException | InvalidKeySpecException e) {
            throw new SecurityException("Failed to generate key - " + e.getMessage(), e);
        }
    }

    /**
     * Generates an ephemeral EC key pair on the given named curve.
     *
     * @param namedCurve the curve, looked up by its name in the provider
     * @return a freshly generated key pair
     * @throws SecurityException if the curve or the provider is not available
     */
    private static KeyPair generateKeyPair(final NamedCurve namedCurve) throws SecurityException {
        log.debug("Generating EC key pair for named curve {} ...", namedCurve.getName());
        try {
            final AlgorithmParameterSpec curveSpec = new ECNamedCurveGenParameterSpec(namedCurve.getName());
            final KeyPairGenerator generator = KeyPairGenerator.getInstance("EC", PROVIDER);
            generator.initialize(curveSpec);
            return generator.generateKeyPair();
        } catch (InvalidAlgorithmParameterException | NoSuchAlgorithmException | NoSuchProviderException e) {
            final String msg = String.format("Failed to generate an EC key pair for curve %s - %s",
                namedCurve.getName(), e.getMessage());
            log.error(msg, e);
            throw new SecurityException(msg, e);
        }
    }

    /**
     * Computes the raw ECDH shared secret between a private key and a peer public key.
     *
     * @param privateKey our EC private key
     * @param publicKey the other party's EC public key
     * @return the shared secret bytes
     * @throws InvalidKeyException if either key is unsuitable for ECDH
     * @throws NoSuchAlgorithmException if ECDH is not available
     * @throws NoSuchProviderException if the provider is not installed
     */
    private static byte[] generateSharedSecret(final PrivateKey privateKey, final PublicKey publicKey)
            throws InvalidKeyException, NoSuchAlgorithmException, NoSuchProviderException {
        final KeyAgreement keyAgreement = KeyAgreement.getInstance("ECDH", PROVIDER);
        keyAgreement.init(privateKey);
        keyAgreement.doPhase(publicKey, true);
        return keyAgreement.generateSecret();
    }

    /**
     * Extracts the encoded originator public key (SubjectPublicKeyInfo bytes) from an {@link OriginatorKeyInfo}.
     * <p>
     * The first {@link ECKeyValue} found among the key values is preferred. If there is none, the first
     * {@link DEREncodedKeyValue} is used.
     * </p>
     *
     * @param originatorKeyInfo the originator key info
     * @return the encoded public key, or {@code null} if no usable key value is present
     * @throws SecurityException if an {@link ECKeyValue} is present but incomplete or malformed
     */
    private static byte[] getOriginatorPublicKeyBytes(final OriginatorKeyInfo originatorKeyInfo)
            throws SecurityException {
        final ECKeyValue ecKeyValue = originatorKeyInfo.getKeyValues().stream()
            .map(kv -> kv.getECKeyValue())
            .filter(ec -> ec != null)
            .findFirst()
            .orElse(null);
        if (ecKeyValue != null) {
            if (ecKeyValue.getPublicKey() == null || ecKeyValue.getPublicKey().getValue() == null) {
                throw new SecurityException("ECKeyValue in OriginatorKeyInfo is missing its PublicKey");
            }
            if (ecKeyValue.getNamedCurve() == null || ecKeyValue.getNamedCurve().getURI() == null) {
                throw new SecurityException("ECKeyValue in OriginatorKeyInfo is missing its NamedCurve");
            }
            return getPublicKeyBytes(Base64Support.decode(ecKeyValue.getPublicKey().getValue()),
                ecKeyValue.getNamedCurve().getURI());
        }
        // Fall back to a DER-encoded key value, which already holds the encoded public key.
        if (!originatorKeyInfo.getDEREncodedKeyValues().isEmpty()) {
            final DEREncodedKeyValue derValue = originatorKeyInfo.getDEREncodedKeyValues().get(0);
            if (derValue != null && derValue.getValue() != null) {
                return Base64Support.decode(derValue.getValue());
            }
        }
        return null;
    }

    /**
     * Returns the first element of the list that is an instance of the given type.
     *
     * @param objects the candidate objects (may be {@code null})
     * @param type the wanted type
     * @param <T> the wanted type
     * @return the first matching element, or {@code null} if there is none
     */
    private static <T> T getFirst(final List<?> objects, final Class<T> type) {
        if (objects == null) {
            return null;
        }
        return objects.stream().filter(type::isInstance).map(type::cast).findFirst().orElse(null);
    }

    /**
     * Derives a secret key from an ECDH shared secret using ConcatKDF.
     * <p>
     * The KDF "other info" is built by concatenating AlgorithmID, PartyUInfo and PartyVInfo, followed by
     * SuppPubInfo and SuppPrivInfo when present. Each value first has its leading padding indicator stripped.
     * </p>
     *
     * @param sharedSecret the ECDH shared secret
     * @param concatKDFParams the ConcatKDF parameters
     * @param keyAlgorithm the algorithm name assigned to the returned key
     * @param keyLength the key length in bits; bits beyond a whole number of bytes are ignored
     * @return the derived key
     * @throws SecurityException if required parameters are missing, the digest is unsupported or the key length
     *           is less than one byte
     */
    private static SecretKey generateKeyAgreementKey(final byte[] sharedSecret, final ConcatKDFParams concatKDFParams,
            final String keyAlgorithm, final int keyLength) throws SecurityException {
        if (concatKDFParams.getDigestMethod() == null || concatKDFParams.getDigestMethod().getAlgorithm() == null) {
            throw new SecurityException("Missing digest method in ConcatKDFParams");
        }
        if (concatKDFParams.getAlgorithmID() == null) {
            throw new SecurityException("Missing AlgorithmID attribute from ConcatKDFParams");
        }
        if (concatKDFParams.getPartyUInfo() == null) {
            throw new SecurityException("Missing PartyUInfo attribute from ConcatKDFParams");
        }
        if (concatKDFParams.getPartyVInfo() == null) {
            throw new SecurityException("Missing PartyVInfo attribute from ConcatKDFParams");
        }

        byte[] otherInfo = Bytes.concat(
            extractConcatKDFParamVal(concatKDFParams.getAlgorithmID()),
            extractConcatKDFParamVal(concatKDFParams.getPartyUInfo()),
            extractConcatKDFParamVal(concatKDFParams.getPartyVInfo()));
        if (concatKDFParams.getSuppPubInfo() != null) {
            otherInfo = Bytes.concat(otherInfo, extractConcatKDFParamVal(concatKDFParams.getSuppPubInfo()));
        }
        if (concatKDFParams.getSuppPrivInfo() != null) {
            otherInfo = Bytes.concat(otherInfo, extractConcatKDFParamVal(concatKDFParams.getSuppPrivInfo()));
        }

        final Digest digest = getConcatKDFDigest(concatKDFParams.getDigestMethod().getAlgorithm());

        final int keyByteLength = keyLength / Byte.SIZE;
        if (keyByteLength <= 0) {
            // SecretKeySpec rejects an empty key, so fail with a meaningful checked exception instead.
            throw new SecurityException("Invalid key length " + keyLength + " for key agreement key");
        }

        final ConcatenationKDFGenerator kdf = new ConcatenationKDFGenerator(digest);
        kdf.init(new KDFParameters(sharedSecret, otherInfo));
        final byte[] keyBytes = new byte[keyByteLength];
        kdf.generateBytes(keyBytes, 0, keyByteLength);
        return new SecretKeySpec(keyBytes, keyAlgorithm);
    }

    /**
     * Maps a ConcatKDF digest method URI to a BouncyCastle digest.
     *
     * @param digestAlgorithm the digest method URI
     * @return a new digest instance
     * @throws SecurityException if the digest URI is not supported
     */
    private static Digest getConcatKDFDigest(final String digestAlgorithm) throws SecurityException {
        switch (digestAlgorithm) {
        case "http://www.w3.org/2001/04/xmlenc#sha256":
            return new SHA256Digest();
        case "http://www.w3.org/2001/04/xmlenc#sha512":
            return new SHA512Digest();
        case "http://www.w3.org/2000/09/xmldsig#sha1":
            return new SHA1Digest();
        case "http://www.w3.org/2001/04/xmldsig-more#sha384":
            return new SHA384Digest();
        case "http://www.w3.org/2001/04/xmlenc#ripemd160":
            return new RIPEMD160Digest();
        default:
            throw new SecurityException("ConcatKDFParams contains unsupported digest algorithm - " + digestAlgorithm);
        }
    }

    /**
     * Strips the leading padding indicator from a ConcatKDF parameter value.
     * <p>
     * The value is handled as follows:
     * </p>
     * <ul>
     * <li>A {@code null} or empty value yields an empty array.</li>
     * <li>If the first byte is {@code 0x08} and the value is longer than one byte, the first two bytes are
     * removed. NOTE(review): the reason for this special case is not evident from this class; confirm.</li>
     * <li>If the first byte is {@code 0x00}, that byte is removed.</li>
     * <li>Any other leading byte is rejected with an {@link IllegalArgumentException}.</li>
     * </ul>
     *
     * @param paramValue the raw attribute value
     * @return the value without its padding indicator
     * @throws SecurityException declared for callers; not thrown by the current implementation
     */
    private static byte[] extractConcatKDFParamVal(final byte[] paramValue) throws SecurityException {
        if (paramValue == null || paramValue.length == 0) {
            return new byte[0];
        }
        if (paramValue[0] == 8 && paramValue.length > 1) {
            return Arrays.copyOfRange(paramValue, 2, paramValue.length);
        }
        if (paramValue[0] != 0) {
            throw new IllegalArgumentException("Unsupported use of padding bits in ConcatKDF parameters");
        }
        return Arrays.copyOfRange(paramValue, 1, paramValue.length);
    }

    /**
     * Builds a DER-encoded SubjectPublicKeyInfo from a raw EC point and a named-curve identifier.
     *
     * @param publicKeyPoint the encoded EC point, placed into the subjectPublicKey BIT STRING
     * @param curveUri the curve OID, optionally prefixed with {@code urn:oid:}
     * @return the DER encoding of the SubjectPublicKeyInfo
     * @throws SecurityException if the curve identifier is not a valid OID or encoding fails
     */
    private static byte[] getPublicKeyBytes(final byte[] publicKeyPoint, final String curveUri)
            throws SecurityException {
        final String curveOid =
            curveUri.startsWith(OID_URN_PREFIX) ? curveUri.substring(OID_URN_PREFIX.length()) : curveUri;
        try {
            final ASN1EncodableVector algorithmIdentifier = new ASN1EncodableVector();
            algorithmIdentifier.add(new ASN1ObjectIdentifier(EC_PUBLIC_KEY_OID));
            algorithmIdentifier.add(new ASN1ObjectIdentifier(curveOid));

            final ASN1EncodableVector subjectPublicKeyInfo = new ASN1EncodableVector();
            subjectPublicKeyInfo.add(new DERSequence(algorithmIdentifier));
            subjectPublicKeyInfo.add(new DERBitString(publicKeyPoint));

            return new DERSequence(subjectPublicKeyInfo).getEncoded(ASN1Encoding.DER);
        } catch (IOException e) {
            throw new SecurityException("Failed to get EC public key bytes", e);
        } catch (IllegalArgumentException e) {
            // Thrown by ASN1ObjectIdentifier for a malformed OID string coming from the message.
            throw new SecurityException("Invalid named curve identifier in ECKeyValue - " + curveUri, e);
        }
    }

    /**
     * Looks up the {@link NamedCurve} of an EC public key.
     * <p>
     * The curve OID is read from the key's encoded AlgorithmIdentifier parameters. It is then resolved through the
     * {@link NamedCurveRegistry} obtained from the {@link ConfigurationService}.
     * </p>
     *
     * @param eCPublicKey the EC public key
     * @return the named curve, or {@code null} in any of these cases: the key cannot be parsed, it is not an EC
     *         key, it does not use a named curve, or the curve is not registered
     * @throws RuntimeException if no {@link NamedCurveRegistry} is configured
     */
    public static NamedCurve getNamedCurve(final ECPublicKey eCPublicKey) {
        final ASN1ObjectIdentifier curveOid = getCurveOid(eCPublicKey);
        if (curveOid == null) {
            return null;
        }
        log.debug("Asking NamedCurveRegistry for curve having OID {} ...", curveOid);
        final NamedCurveRegistry namedCurveRegistry = ConfigurationService.get(NamedCurveRegistry.class);
        if (namedCurveRegistry == null) {
            throw new RuntimeException("NamedCurveRegistry is not available");
        }
        final NamedCurve namedCurve = namedCurveRegistry.get(curveOid.getId());
        if (namedCurve == null) {
            log.debug("NamedCurve with OID {} was not found in the NamedCurveRegistry", curveOid.getId());
            return null;
        }
        log.debug("Looked up NamedCurve {} ({}) (keyLength:{})",
            namedCurve.getObjectIdentifier(), namedCurve.getName(), namedCurve.getKeyLength());
        return namedCurve;
    }

    /**
     * Extracts the named-curve OID from an EC public key's encoded SubjectPublicKeyInfo.
     *
     * @param eCPublicKey the public key (may be {@code null})
     * @return the curve OID, or {@code null} (after logging) if it cannot be determined
     */
    private static ASN1ObjectIdentifier getCurveOid(final ECPublicKey eCPublicKey) {
        final byte[] encoded = eCPublicKey != null ? eCPublicKey.getEncoded() : null;
        if (encoded == null) {
            log.error(PARSE_FAILURE_MSG);
            return null;
        }
        try {
            // SubjectPublicKeyInfo ::= SEQUENCE { algorithm AlgorithmIdentifier, subjectPublicKey BIT STRING }
            // AlgorithmIdentifier  ::= SEQUENCE { algorithm OBJECT IDENTIFIER, parameters ANY OPTIONAL }
            final ASN1Sequence subjectPublicKeyInfo = ASN1Sequence.getInstance(encoded);
            if (subjectPublicKeyInfo.size() < 1) {
                log.error(PARSE_FAILURE_MSG);
                return null;
            }
            final ASN1Sequence algorithmIdentifier = ASN1Sequence.getInstance(subjectPublicKeyInfo.getObjectAt(0));
            if (algorithmIdentifier.size() < 1) {
                log.error(PARSE_FAILURE_MSG);
                return null;
            }
            final ASN1ObjectIdentifier keyType = ASN1ObjectIdentifier.getInstance(algorithmIdentifier.getObjectAt(0));
            if (!EC_PUBLIC_KEY_OID.equals(keyType.getId())) {
                log.error("The provided public key with key type OID {} is not a valid EC public key",
                    keyType.getId());
                return null;
            }
            if (algorithmIdentifier.size() < 2) {
                log.error(PARSE_FAILURE_MSG);
                return null;
            }
            // Parameters that are not an OID (e.g. explicit curve parameters) are rejected by getInstance.
            return ASN1ObjectIdentifier.getInstance(algorithmIdentifier.getObjectAt(1));
        } catch (IllegalArgumentException e) {
            log.error(PARSE_FAILURE_MSG, e);
            return null;
        }
    }
}
