/*
 * Decompiled with CFR 0.152.
 */
package io.helidon.security.jwt;

import io.helidon.common.Errors;
import io.helidon.security.jwt.JwtException;
import io.helidon.security.jwt.JwtHeaders;
import io.helidon.security.jwt.SignedJwt;
import io.helidon.security.jwt.jwk.Jwk;
import io.helidon.security.jwt.jwk.JwkEC;
import io.helidon.security.jwt.jwk.JwkKeys;
import io.helidon.security.jwt.jwk.JwkRSA;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.MGF1ParameterSpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.OAEPParameterSpec;
import javax.crypto.spec.PSource;
import javax.crypto.spec.SecretKeySpec;

/**
 * A JWT encrypted as a JWE in compact serialization
 * ({@code header.encryptedKey.iv.ciphertext.authTag}, RFC 7516).
 * <p>
 * A random content encryption key (CEK) is wrapped with an RSA key management algorithm
 * ({@link SupportedAlgorithm}) and the serialized {@link SignedJwt} is encrypted with an AES
 * content encryption algorithm ({@link SupportedEncryption}). Instances are immutable; the
 * byte-array accessors return copies.
 */
public final class EncryptedJwt {
    /** Compact JWE: exactly five non-empty, dot-free, whitespace-free segments. */
    private static final Pattern JWE_PATTERN =
            Pattern.compile("([^.\\s]+)\\.([^.\\s]+)\\.([^.\\s]+)\\.([^.\\s]+)\\.([^.\\s]+)");
    private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder();
    private static final Base64.Encoder URL_ENCODER = Base64.getUrlEncoder().withoutPadding();
    /** JCA transformation used to wrap/unwrap the CEK for each key management algorithm. */
    private static final Map<SupportedAlgorithm, String> RSA_ALGORITHMS = Map.of(
            SupportedAlgorithm.RSA_OAEP, "RSA/ECB/OAEPWithSHA-1AndMGF1Padding",
            SupportedAlgorithm.RSA_OAEP_256, "RSA/ECB/OAEPPadding",
            SupportedAlgorithm.RSA1_5, "RSA/ECB/PKCS1Padding");
    /**
     * RFC 7518 section 4.3: RSA-OAEP-256 uses SHA-256 for both the OAEP digest and MGF1. The
     * SunJCE "OAEPWithSHA-256AndMGF1Padding" transformation uses SHA-1 for MGF1 by default, so
     * the parameters are passed explicitly together with plain "OAEPPadding".
     */
    private static final OAEPParameterSpec OAEP_SHA256 = new OAEPParameterSpec(
            "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, PSource.PSpecified.DEFAULT);
    /** Content encryption implementation for each supported {@code enc} value. */
    private static final Map<SupportedEncryption, AesAlgorithm> CONTENT_ENCRYPTION = Map.of(
            SupportedEncryption.A128GCM, new AesGcmAlgorithm(128),
            SupportedEncryption.A192GCM, new AesGcmAlgorithm(192),
            SupportedEncryption.A256GCM, new AesGcmAlgorithm(256),
            SupportedEncryption.A128CBC_HS256, new AesAlgorithmWithHmac("AES/CBC/PKCS5Padding", 128, 16, "HmacSHA256"),
            SupportedEncryption.A192CBC_HS384, new AesAlgorithmWithHmac("AES/CBC/PKCS5Padding", 192, 16, "HmacSHA384"),
            SupportedEncryption.A256CBC_HS512, new AesAlgorithmWithHmac("AES/CBC/PKCS5Padding", 256, 16, "HmacSHA512"));

    private final String token;
    /** Encoded protected header exactly as it appears in the token; it is the JWE AAD. */
    private final String headerBase64;
    private final JwtHeaders header;
    private final byte[] iv;
    private final byte[] encryptedKey;
    private final byte[] authTag;
    private final byte[] encryptedPayload;

    private EncryptedJwt(String token,
                         String headerBase64,
                         JwtHeaders header,
                         byte[] iv,
                         byte[] encryptedKey,
                         byte[] authTag,
                         byte[] encryptedPayload) {
        this.token = token;
        this.headerBase64 = headerBase64;
        this.header = header;
        this.iv = iv;
        this.encryptedKey = encryptedKey;
        this.authTag = authTag;
        this.encryptedPayload = encryptedPayload;
    }

    /**
     * Creates a builder that encrypts the given signed JWT.
     *
     * @param jwt signed JWT to encrypt, must not be null
     * @return new builder
     */
    public static Builder builder(SignedJwt jwt) {
        return new Builder(jwt);
    }

    /**
     * Encrypts the signed JWT with the given JWK using the builder defaults
     * (RSA-OAEP key wrapping, A256GCM content encryption).
     *
     * @param jwt signed JWT to encrypt
     * @param jwk JWK whose public key wraps the CEK
     * @return encrypted JWT
     */
    public static EncryptedJwt create(SignedJwt jwt, Jwk jwk) {
        return EncryptedJwt.builder(jwt).jwk(jwk).build();
    }

    /**
     * Parses a compact JWE token, including its protected header.
     *
     * @param token compact JWE
     * @return parsed (still encrypted) JWT
     * @throws JwtException if the token does not have five dot-separated segments
     */
    public static EncryptedJwt parseToken(String token) {
        Matcher matcher = matchJwe(token);
        Errors.Collector collector = Errors.collector();
        JwtHeaders header = JwtHeaders.parseBase64(matcher.group(1), collector);
        return parse(token, collector, header, matcher);
    }

    /**
     * Parses a compact JWE token whose header has already been parsed by the caller.
     * The token's raw header segment is still used as the AAD during decryption.
     *
     * @param header already parsed protected header of this token
     * @param token  compact JWE
     * @return parsed (still encrypted) JWT
     * @throws JwtException if the token does not have five dot-separated segments
     */
    public static EncryptedJwt parseToken(JwtHeaders header, String token) {
        Matcher matcher = matchJwe(token);
        return parse(token, Errors.collector(), header, matcher);
    }

    // Matches the token against the compact JWE shape or fails with JwtException.
    private static Matcher matchJwe(String token) {
        Matcher matcher = JWE_PATTERN.matcher(token);
        if (!matcher.matches()) {
            throw new JwtException("Not a JWE token: " + token);
        }
        return matcher;
    }

    // Decodes segments 2-5 of a matched token; any decoding failure is reported via checkValid().
    private static EncryptedJwt parse(String token, Errors.Collector collector, JwtHeaders header, Matcher matcher) {
        byte[] encryptedKey = decodeBytes(matcher.group(2), collector, "JWE encrypted key");
        byte[] iv = decodeBytes(matcher.group(3), collector, "JWE initialization vector");
        byte[] encryptedPayload = decodeBytes(matcher.group(4), collector, "JWE payload");
        byte[] authTag = decodeBytes(matcher.group(5), collector, "JWE authentication tag");
        collector.collect().checkValid();
        return new EncryptedJwt(token, matcher.group(1), header, iv, encryptedKey, authTag, encryptedPayload);
    }

    // Creates an RSA cipher for CEK wrapping/unwrapping, with explicit OAEP parameters where needed.
    private static Cipher rsaCipher(SupportedAlgorithm algorithm, int mode, Key key) throws GeneralSecurityException {
        Cipher cipher = Cipher.getInstance(RSA_ALGORITHMS.get(algorithm));
        if (algorithm == SupportedAlgorithm.RSA_OAEP_256) {
            cipher.init(mode, key, OAEP_SHA256);
        } else {
            cipher.init(mode, key);
        }
        return cipher;
    }

    private static byte[] encryptRsa(SupportedAlgorithm algorithm, PublicKey publicKey, byte[] unencryptedKey) {
        try {
            return rsaCipher(algorithm, Cipher.ENCRYPT_MODE, publicKey).doFinal(unencryptedKey);
        } catch (Exception e) {
            throw new JwtException("Exception during rsa key encryption occurred.", e);
        }
    }

    // NOTE(review): for RSA1_5, RFC 7516 section 11.5 recommends not revealing unwrap failures
    // (e.g. substituting a random CEK) to avoid padding-oracle attacks; this method throws instead.
    private static byte[] decryptRsa(SupportedAlgorithm algorithm, PrivateKey privateKey, byte[] encryptedKey) {
        try {
            return rsaCipher(algorithm, Cipher.DECRYPT_MODE, privateKey).doFinal(encryptedKey);
        } catch (Exception e) {
            throw new JwtException("Exception during rsa key decryption occurred.", e);
        }
    }

    private static String encode(String string) {
        return EncryptedJwt.encode(string.getBytes(StandardCharsets.UTF_8));
    }

    private static String encode(byte[] bytes) {
        return URL_ENCODER.encodeToString(bytes);
    }

    // Returns the decoded bytes, or null after recording a fatal error in the collector.
    private static byte[] decodeBytes(String base64, Errors.Collector collector, String description) {
        try {
            return URL_DECODER.decode(base64);
        } catch (Exception e) {
            collector.fatal((Object) base64, description + " is not a base64 encoded string.");
            return null;
        }
    }

    /**
     * Decrypts using the JWK selected by the {@code kid} header from the given key set.
     *
     * @param jwkKeys key set to look up the decryption key in
     * @return decrypted signed JWT (its signature is not verified here)
     */
    public SignedJwt decrypt(JwkKeys jwkKeys) {
        return this.decrypt(jwkKeys, null);
    }

    /**
     * Decrypts using the given JWK; if a {@code kid} header is present it must match the JWK's key id.
     *
     * @param jwk decryption key
     * @return decrypted signed JWT (its signature is not verified here)
     */
    public SignedJwt decrypt(Jwk jwk) {
        return this.decrypt(null, jwk);
    }

    /**
     * Decrypts this JWE.
     * <p>
     * Key selection works as follows:
     * <ul>
     *     <li>If the {@code kid} header is present and {@code jwkKeys} is non-null, the key is
     *     looked up in {@code jwkKeys}.</li>
     *     <li>If {@code kid} is present but {@code jwkKeys} is null, {@code defaultJwk} is used
     *     when its key id matches.</li>
     *     <li>If there is no {@code kid}, {@code defaultJwk} is used.</li>
     * </ul>
     * Header and key problems are collected and raised through {@code Errors.Collector#checkValid()}.
     *
     * @param jwkKeys    key set, may be null
     * @param defaultJwk fallback key, may be null
     * @return decrypted signed JWT (its signature is not verified here)
     * @throws JwtException if the content encryption is unsupported, or if key unwrapping or
     *                      content decryption/authentication fails
     */
    public SignedJwt decrypt(JwkKeys jwkKeys, Jwk defaultJwk) {
        Errors.Collector errors = Errors.collector();
        String alg = header.algorithm().orElse(null);
        String kid = header.keyId().orElse(null);
        String enc = header.encryption().orElse(null);

        Jwk jwk = resolveJwk(jwkKeys, defaultJwk, kid, errors);
        if (enc == null) {
            errors.fatal("Content encryption algorithm not set.");
        }
        SupportedAlgorithm keyAlgorithm = null;
        if (alg == null) {
            errors.fatal("No alg header was present among JWE headers");
        } else {
            try {
                keyAlgorithm = SupportedAlgorithm.getValue(alg);
            } catch (IllegalArgumentException e) {
                errors.fatal("Value of the claim alg not supported. alg: " + alg);
            }
        }
        PrivateKey privateKey = resolvePrivateKey(jwk, kid, errors);
        errors.collect().checkValid();

        // Resolve the content algorithm before doing any RSA work.
        AesAlgorithm contentEncryption;
        try {
            contentEncryption = CONTENT_ENCRYPTION.get(SupportedEncryption.getValue(enc));
        } catch (IllegalArgumentException e) {
            throw new JwtException("Unsupported content encryption: " + enc);
        }

        byte[] decryptedKey = decryptRsa(keyAlgorithm, privateKey, encryptedKey);
        // RFC 7516: AAD is ASCII(Encoded Protected Header) exactly as transmitted, not a re-serialization.
        EncryptionParts encryptionParts = new EncryptionParts(decryptedKey,
                                                              iv,
                                                              headerBase64.getBytes(StandardCharsets.US_ASCII),
                                                              encryptedPayload,
                                                              authTag);
        String decryptedPayload = new String(contentEncryption.decrypt(encryptionParts), StandardCharsets.UTF_8);
        return SignedJwt.parseToken(decryptedPayload);
    }

    // Selects the decryption JWK for the kid header; records a fatal error and returns null if none applies.
    private static Jwk resolveJwk(JwkKeys jwkKeys, Jwk defaultJwk, String kid, Errors.Collector errors) {
        if (kid == null) {
            if (defaultJwk == null) {
                errors.fatal("Could not find any suitable JWK.");
            }
            return defaultJwk;
        }
        if (jwkKeys != null) {
            return jwkKeys.forKeyId(kid).orElse(null);
        }
        if (defaultJwk != null && kid.equals(defaultJwk.keyId())) {
            return defaultJwk;
        }
        errors.fatal("Could not find JWK for kid: " + kid);
        return null;
    }

    // Extracts the private key from an RSA or EC JWK; records a fatal error and returns null otherwise.
    private static PrivateKey resolvePrivateKey(Jwk jwk, String kid, Errors.Collector errors) {
        if (jwk instanceof JwkRSA) {
            PrivateKey key = ((JwkRSA) jwk).privateKey().orElse(null);
            if (key == null) {
                errors.fatal("No private key present in RSA JWK kid: " + jwk.keyId());
            }
            return key;
        }
        if (jwk instanceof JwkEC) {
            PrivateKey key = ((JwkEC) jwk).privateKey().orElse(null);
            if (key == null) {
                errors.fatal("No private key present in EC JWK kid: " + jwk.keyId());
            }
            return key;
        }
        if (jwk != null) {
            errors.fatal("Not supported JWK type: " + jwk.keyType() + ", JWK class: " + jwk.getClass().getName());
        } else {
            errors.fatal("No JWK found for key id: " + kid);
        }
        return null;
    }

    /**
     * Protected header of this JWE.
     *
     * @return header
     */
    public JwtHeaders headers() {
        return this.header;
    }

    /**
     * Compact serialized form of this JWE.
     *
     * @return token
     */
    public String token() {
        return this.token;
    }

    /**
     * Initialization vector.
     *
     * @return copy of the IV bytes
     */
    public byte[] iv() {
        return Arrays.copyOf(this.iv, this.iv.length);
    }

    /**
     * RSA-wrapped content encryption key.
     *
     * @return copy of the encrypted key bytes
     */
    public byte[] encryptedKey() {
        return Arrays.copyOf(this.encryptedKey, this.encryptedKey.length);
    }

    /**
     * Authentication tag.
     *
     * @return copy of the tag bytes
     */
    public byte[] authTag() {
        return Arrays.copyOf(this.authTag, this.authTag.length);
    }

    /**
     * Ciphertext of the signed JWT.
     *
     * @return copy of the ciphertext bytes
     */
    public byte[] encryptedPayload() {
        return Arrays.copyOf(this.encryptedPayload, this.encryptedPayload.length);
    }

    /**
     * Fluent builder for {@link EncryptedJwt}.
     * Defaults are {@link SupportedAlgorithm#RSA_OAEP} and {@link SupportedEncryption#A256GCM}.
     */
    public static class Builder
    implements io.helidon.common.Builder<Builder, EncryptedJwt> {
        private final SignedJwt jwt;
        private Jwk jwk;
        private SupportedAlgorithm algorithm = SupportedAlgorithm.RSA_OAEP;
        private SupportedEncryption encryption = SupportedEncryption.A256GCM;
        private JwkKeys jwks;
        private String kid;

        private Builder(SignedJwt jwt) {
            this.jwt = Objects.requireNonNull(jwt);
        }

        /**
         * Selects the encryption key from a key set. Used only when no explicit {@link #jwk(Jwk)}
         * is configured, in which case {@code kid} is also written to the header.
         *
         * @param jwkKeys key set
         * @param kid     key id to look up
         * @return updated builder
         */
        public Builder jwks(JwkKeys jwkKeys, String kid) {
            this.jwks = Objects.requireNonNull(jwkKeys);
            this.kid = Objects.requireNonNull(kid);
            return this;
        }

        /**
         * Sets the JWK whose public key wraps the CEK.
         *
         * @param jwk encryption key
         * @return updated builder
         */
        public Builder jwk(Jwk jwk) {
            this.jwk = Objects.requireNonNull(jwk);
            return this;
        }

        /**
         * Sets the key management algorithm.
         *
         * @param algorithm RSA key wrapping algorithm
         * @return updated builder
         */
        public Builder algorithm(SupportedAlgorithm algorithm) {
            this.algorithm = Objects.requireNonNull(algorithm);
            return this;
        }

        /**
         * Sets the content encryption algorithm.
         *
         * @param encryption AES content encryption
         * @return updated builder
         */
        public Builder encryption(SupportedEncryption encryption) {
            this.encryption = Objects.requireNonNull(encryption);
            return this;
        }

        /**
         * Encrypts the signed JWT. The builder is not modified, so it can be reused.
         *
         * @return encrypted JWT
         * @throws JwtException if no usable JWK is configured or encryption fails
         */
        public EncryptedJwt build() {
            // Local builder/key so repeated build() calls do not leak state into each other.
            JwtHeaders.Builder headersBuilder = JwtHeaders.builder();
            headersBuilder.algorithm(this.algorithm.toString());
            headersBuilder.encryption(this.encryption.toString());
            headersBuilder.contentType("JWT");

            Jwk encryptionJwk = this.jwk;
            if (encryptionJwk == null && this.jwks != null) {
                encryptionJwk = this.jwks.forKeyId(this.kid)
                        .orElseThrow(() -> new JwtException("Could not determine which JWK should be used for encryption."));
                headersBuilder.keyId(this.kid);
            }
            if (encryptionJwk == null) {
                throw new JwtException("No JWK specified for encrypted JWT creation.");
            }
            PublicKey publicKey;
            if (encryptionJwk instanceof JwkRSA) {
                publicKey = ((JwkRSA) encryptionJwk).publicKey();
            } else if (encryptionJwk instanceof JwkEC) {
                publicKey = ((JwkEC) encryptionJwk).publicKey();
            } else {
                throw new JwtException("Unsupported JWK type: " + encryptionJwk.keyType());
            }

            JwtHeaders headers = headersBuilder.build();
            String headersBase64 = EncryptedJwt.encode(headers.headerJson().toString());
            AesAlgorithm contentEncryption = CONTENT_ENCRYPTION.get(this.encryption);
            EncryptionParts parts = contentEncryption.encrypt(this.jwt.tokenContent().getBytes(StandardCharsets.UTF_8),
                                                              headersBase64.getBytes(StandardCharsets.US_ASCII));
            byte[] encryptedAesKey = EncryptedJwt.encryptRsa(this.algorithm, publicKey, parts.key());
            String token = String.join(".",
                                       headersBase64,
                                       EncryptedJwt.encode(encryptedAesKey),
                                       EncryptedJwt.encode(parts.iv()),
                                       EncryptedJwt.encode(parts.encryptedContent()),
                                       EncryptedJwt.encode(parts.authTag()));
            return new EncryptedJwt(token, headersBase64, headers, parts.iv(), encryptedAesKey,
                                    parts.authTag(), parts.encryptedContent());
        }
    }

    /**
     * Supported JWE key management ({@code alg}) algorithms.
     */
    public enum SupportedAlgorithm {
        /** RSAES OAEP with SHA-1 and MGF1-SHA-1. */
        RSA_OAEP("RSA-OAEP"),
        /** RSAES OAEP with SHA-256 and MGF1-SHA-256. */
        RSA_OAEP_256("RSA-OAEP-256"),
        /** RSAES PKCS#1 v1.5. */
        RSA1_5("RSA1_5");

        private final String algorithmName;

        SupportedAlgorithm(String algorithmName) {
            this.algorithmName = algorithmName;
        }

        @Override
        public String toString() {
            return this.algorithmName;
        }

        // Case-insensitive lookup by header value; IllegalArgumentException if unknown (or null).
        static SupportedAlgorithm getValue(String value) {
            for (SupportedAlgorithm v : SupportedAlgorithm.values()) {
                if (v.algorithmName.equalsIgnoreCase(value)) {
                    return v;
                }
            }
            throw new IllegalArgumentException();
        }
    }

    /**
     * Intermediate values of one content encryption/decryption operation.
     */
    private static final class EncryptionParts {
        private final byte[] key;
        private final byte[] iv;
        private final byte[] aad;
        private final byte[] encryptedContent;
        private final byte[] authTag;

        private EncryptionParts(byte[] key, byte[] iv, byte[] aad, byte[] encryptedContent, byte[] authTag) {
            this.key = key;
            this.iv = iv;
            this.aad = aad;
            this.encryptedContent = encryptedContent;
            this.authTag = authTag;
        }

        byte[] key() {
            return this.key;
        }

        byte[] iv() {
            return this.iv;
        }

        byte[] aad() {
            return this.aad;
        }

        byte[] encryptedContent() {
            return this.encryptedContent;
        }

        byte[] authTag() {
            return this.authTag;
        }
    }

    /**
     * Supported JWE content encryption ({@code enc}) algorithms.
     */
    public enum SupportedEncryption {
        A128GCM("A128GCM"),
        A192GCM("A192GCM"),
        A256GCM("A256GCM"),
        A128CBC_HS256("A128CBC-HS256"),
        A192CBC_HS384("A192CBC-HS384"),
        A256CBC_HS512("A256CBC-HS512");

        private final String encryptionName;

        SupportedEncryption(String encryptionName) {
            this.encryptionName = encryptionName;
        }

        @Override
        public String toString() {
            return this.encryptionName;
        }

        // Case-insensitive lookup by header value; IllegalArgumentException if unknown (or null).
        static SupportedEncryption getValue(String value) {
            for (SupportedEncryption v : SupportedEncryption.values()) {
                if (v.encryptionName.equalsIgnoreCase(value)) {
                    return v;
                }
            }
            throw new IllegalArgumentException();
        }
    }

    /**
     * Base AES content encryption (IvParameterSpec, no AAD). Subclasses customise the parameter
     * spec, AAD handling, CEK layout and authentication.
     */
    private static class AesAlgorithm {
        static final SecureRandom RANDOM = new SecureRandom();
        private final String transformation;
        private final int keySize;
        private final int ivSize;

        private AesAlgorithm(String transformation, int keySize, int ivSize) {
            this.transformation = transformation;
            this.keySize = keySize;
            this.ivSize = ivSize;
        }

        // Required CEK length in bytes.
        int keyLength() {
            return keySize / 8;
        }

        // Generates a fresh random CEK of keySize bits.
        byte[] generateKey() throws NoSuchAlgorithmException {
            KeyGenerator kgen = KeyGenerator.getInstance("AES");
            kgen.init(keySize, RANDOM);
            SecretKey secretKey = kgen.generateKey();
            return secretKey.getEncoded();
        }

        // AES key handed to the cipher, derived from the CEK (the whole CEK by default).
        SecretKeySpec encryptionKey(byte[] key) {
            return new SecretKeySpec(key, "AES");
        }

        // Rejects an unwrapped CEK whose length does not match this algorithm.
        void checkKeyLength(byte[] key) {
            if (key.length != keyLength()) {
                throw new JwtException("Invalid content encryption key length: expected "
                                               + keyLength() + " bytes, got " + key.length);
            }
        }

        EncryptionParts encrypt(byte[] plainContent, byte[] aad) {
            try {
                byte[] key = generateKey();
                byte[] iv = new byte[ivSize];
                RANDOM.nextBytes(iv);
                EncryptionParts parts = new EncryptionParts(key, iv, aad, null, null);
                Cipher cipher = Cipher.getInstance(transformation);
                cipher.init(Cipher.ENCRYPT_MODE, encryptionKey(key), createParameterSpec(parts));
                postCipherConstruct(cipher, parts);
                byte[] encryptedContent = cipher.doFinal(plainContent);
                return new EncryptionParts(key, iv, aad, encryptedContent, null);
            } catch (Exception e) {
                throw new JwtException("Exception during content encryption", e);
            }
        }

        byte[] decrypt(EncryptionParts encryptionParts) {
            byte[] key = encryptionParts.key();
            checkKeyLength(key);
            try {
                Cipher cipher = Cipher.getInstance(transformation);
                cipher.init(Cipher.DECRYPT_MODE, encryptionKey(key), createParameterSpec(encryptionParts));
                postCipherConstruct(cipher, encryptionParts);
                return cipher.doFinal(encryptionParts.encryptedContent());
            } catch (Exception e) {
                throw new JwtException("Exception during content decryption.", e);
            }
        }

        protected void postCipherConstruct(Cipher cipher, EncryptionParts encryptionParts) {
        }

        protected AlgorithmParameterSpec createParameterSpec(EncryptionParts encryptionParts) {
            return new IvParameterSpec(encryptionParts.iv());
        }
    }

    /**
     * AES-GCM with a 96-bit IV and 128-bit tag. The JCA cipher appends the tag to the ciphertext;
     * JWE carries it as a separate segment, so it is split off on encrypt and re-appended on decrypt.
     */
    private static final class AesGcmAlgorithm
    extends AesAlgorithm {
        private static final int TAG_LENGTH = 16;

        private AesGcmAlgorithm(int keySize) {
            super("AES/GCM/NoPadding", keySize, 12);
        }

        @Override
        EncryptionParts encrypt(byte[] plainContent, byte[] aad) {
            EncryptionParts parts = super.encrypt(plainContent, aad);
            byte[] whole = parts.encryptedContent();
            int contentLength = whole.length - TAG_LENGTH;
            byte[] encryptedContent = Arrays.copyOfRange(whole, 0, contentLength);
            byte[] authTag = Arrays.copyOfRange(whole, contentLength, whole.length);
            return new EncryptionParts(parts.key(), parts.iv(), parts.aad(), encryptedContent, authTag);
        }

        @Override
        byte[] decrypt(EncryptionParts encryptionParts) {
            byte[] content = encryptionParts.encryptedContent();
            byte[] tag = encryptionParts.authTag();
            byte[] combined = Arrays.copyOf(content, content.length + tag.length);
            System.arraycopy(tag, 0, combined, content.length, tag.length);
            return super.decrypt(new EncryptionParts(encryptionParts.key(), encryptionParts.iv(),
                                                     encryptionParts.aad(), combined, tag));
        }

        @Override
        protected AlgorithmParameterSpec createParameterSpec(EncryptionParts encryptionParts) {
            return new GCMParameterSpec(TAG_LENGTH * 8, encryptionParts.iv());
        }

        @Override
        protected void postCipherConstruct(Cipher cipher, EncryptionParts encryptionParts) {
            cipher.updateAAD(encryptionParts.aad());
        }
    }

    /**
     * AES_CBC_HMAC_SHA2 per RFC 7518 section 5.2.
     * <ul>
     *     <li>The CEK is {@code MAC_KEY || ENC_KEY}, each {@code keySize} bits.</li>
     *     <li>The tag is the first {@code keySize} bits of
     *     {@code HMAC(MAC_KEY, AAD || IV || ciphertext || AL)}.</li>
     *     <li>AL is the AAD length in bits as a 64-bit big-endian integer.</li>
     * </ul>
     */
    private static final class AesAlgorithmWithHmac
    extends AesAlgorithm {
        private final String hmac;
        /** Length in bytes of MAC_KEY, of ENC_KEY and of the truncated tag. */
        private final int halfKeyLength;

        private AesAlgorithmWithHmac(String cipher, int keySize, int ivSize, String hmac) {
            super(cipher, keySize, ivSize);
            this.hmac = hmac;
            this.halfKeyLength = keySize / 8;
        }

        @Override
        int keyLength() {
            return 2 * halfKeyLength;
        }

        @Override
        byte[] generateKey() {
            // Composite keys (32/48/64 bytes) exceed what an AES KeyGenerator produces.
            byte[] key = new byte[keyLength()];
            RANDOM.nextBytes(key);
            return key;
        }

        @Override
        SecretKeySpec encryptionKey(byte[] key) {
            // ENC_KEY is the second half of the CEK.
            return new SecretKeySpec(key, halfKeyLength, halfKeyLength, "AES");
        }

        @Override
        EncryptionParts encrypt(byte[] plainContent, byte[] aad) {
            EncryptionParts parts = super.encrypt(plainContent, aad);
            byte[] authTag = authenticationTag(parts);
            return new EncryptionParts(parts.key(), parts.iv(), parts.aad(), parts.encryptedContent(), authTag);
        }

        @Override
        byte[] decrypt(EncryptionParts encryptionParts) {
            checkKeyLength(encryptionParts.key());
            // Constant-time comparison; verify before touching the ciphertext.
            if (!MessageDigest.isEqual(authenticationTag(encryptionParts), encryptionParts.authTag())) {
                throw new JwtException("HMAC signature does not match");
            }
            return super.decrypt(encryptionParts);
        }

        // Computes the truncated HMAC tag over AAD || IV || ciphertext || AL using MAC_KEY.
        private byte[] authenticationTag(EncryptionParts parts) {
            Mac mac = macInstance();
            try {
                mac.init(new SecretKeySpec(parts.key(), 0, halfKeyLength, hmac));
            } catch (InvalidKeyException e) {
                throw new JwtException("Exception occurred while HMAC signature", e);
            }
            byte[] aad = parts.aad();
            mac.update(aad);
            mac.update(parts.iv());
            mac.update(parts.encryptedContent());
            mac.update(aadBitLength(aad.length));
            return Arrays.copyOf(mac.doFinal(), halfKeyLength);
        }

        // AL: AAD length in bits as an unsigned 64-bit big-endian integer.
        private static byte[] aadBitLength(int aadLength) {
            long bits = (long) aadLength * 8;
            byte[] al = new byte[8];
            for (int i = al.length - 1; i >= 0; i--) {
                al[i] = (byte) bits;
                bits >>>= 8;
            }
            return al;
        }

        private Mac macInstance() {
            try {
                return Mac.getInstance(this.hmac);
            } catch (NoSuchAlgorithmException e) {
                throw new JwtException("Could not find MAC instance: " + this.hmac, e);
            }
        }
    }
}

