/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.security.authc.jwt;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.crypto.MACSigner;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyOperation;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import java.nio.charset.StandardCharsets;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.settings.SettingsException;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
import org.elasticsearch.xpack.security.authc.jwt.JwkSetLoader;
import org.elasticsearch.xpack.security.authc.jwt.JwtUtil;

public class JwkValidateUtil {
    private static final Logger logger = LogManager.getLogger(JwkValidateUtil.class);

    /** Smallest RSA modulus size, in bits, accepted for any supported RSA signature algorithm. */
    private static final int MIN_RSA_KEY_BITS = 2048;

    /**
     * Reduces a list of JWKs and a list of configured algorithm names to the subsets that are usable together.
     * <p>
     * A JWK is kept if its key use is absent or {@code SIGNATURE}, its key operations are absent or contain
     * {@code VERIFY}, and it matches (per {@link #isMatch}) at least one of {@code algs}. An algorithm is kept
     * if at least one kept JWK matches it. Input order (and duplicates in {@code algs}) is preserved.
     * Filtering progress is written to a {@link JwtUtil.TraceBuffer}.
     *
     * @param jwks candidate JWKs
     * @param algs configured algorithm names
     * @return the filtered JWKs and algorithms, as unmodifiable lists
     * @throws SettingsException declared for callers; not thrown directly by this method
     */
    static JwkSetLoader.JwksAlgs filterJwksAndAlgorithms(final List<JWK> jwks, final List<String> algs) throws SettingsException {
        try (JwtUtil.TraceBuffer tracer = new JwtUtil.TraceBuffer(logger)) {
            tracer.append("Filtering [{}] JWKs for the following algorithms [{}].", jwks.size(), String.join(",", algs));

            final Predicate<JWK> keyUsePredicate = j -> j.getKeyUse() == null || KeyUse.SIGNATURE.equals(j.getKeyUse());
            final List<JWK> jwksSig = jwks.stream().filter(keyUsePredicate).toList();
            tracer.append("[{}] remaining JWKs after KeyUse [SIGNATURE] filter.", jwksSig.size());

            final Predicate<JWK> keyOpPredicate = j -> j.getKeyOperations() == null
                || j.getKeyOperations().contains(KeyOperation.VERIFY);
            final List<JWK> jwksVerify = jwksSig.stream().filter(keyOpPredicate).toList();
            tracer.append("[{}] remaining JWKs after KeyOperation [VERIFY] filter.", jwksVerify.size());

            // Evaluate every (JWK, algorithm) pair exactly once. Filtering JWKs and algorithms in two separate
            // passes would call isMatch twice per pair, duplicating its trace lines and exception logs.
            final List<JWK> jwksMatched = new ArrayList<>(jwksVerify.size());
            final Set<String> algsMatched = new HashSet<>();
            for (final JWK jwk : jwksVerify) {
                boolean jwkMatched = false;
                for (final String alg : algs) {
                    if (isMatch(jwk, alg, tracer)) {
                        jwkMatched = true;
                        algsMatched.add(alg);
                    }
                }
                if (jwkMatched) {
                    jwksMatched.add(jwk);
                }
            }
            final List<JWK> jwksFiltered = List.copyOf(jwksMatched);
            tracer.append("[{}] remaining JWKs after algorithms name filter.", jwksFiltered.size());

            // Filter the original list (rather than copying the set) to keep the configured order and duplicates.
            final List<String> algsFiltered = algs.stream().filter(algsMatched::contains).toList();
            tracer.append(
                "[{}] remaining JWKs after configured algorithms [{}] filter.",
                jwksFiltered.size(),
                String.join(",", algsFiltered)
            );
            return new JwkSetLoader.JwksAlgs(jwksFiltered, algsFiltered);
        }
    }

    /**
     * Checks whether a JWK is acceptable for a signature algorithm.
     * <ul>
     *   <li>HMAC algorithms require an {@link OctetSequenceKey} at least as large as
     *       {@link MACSigner#getMinRequiredSecretLength} for that algorithm.</li>
     *   <li>RSA algorithms require an {@link RSAKey} of at least {@value #MIN_RSA_KEY_BITS} bits.</li>
     *   <li>EC algorithms require an {@link ECKey} whose curve is allowed for that algorithm.</li>
     * </ul>
     * Any other combination does not match. A rejection reason is appended to {@code tracer}. Unexpected
     * exceptions are logged and treated as "no match".
     *
     * @param jwk the candidate key
     * @param algorithm the algorithm name to test against
     * @param tracer receives human-readable rejection reasons
     * @return {@code true} if the key can be used with the algorithm
     */
    static boolean isMatch(final JWK jwk, final String algorithm, final JwtUtil.TraceBuffer tracer) {
        try {
            if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(algorithm) && jwk instanceof OctetSequenceKey jwkHmac) {
                final int bits = jwkHmac.size();
                final int min = MACSigner.getMinRequiredSecretLength(JWSAlgorithm.parse(algorithm));
                final boolean isMatch = bits >= min;
                if (!isMatch) {
                    tracer.append("HMAC JWK [" + bits + "] bits too small for algorithm [" + algorithm + "] minimum [" + min + "].");
                }
                return isMatch;
            }
            if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_RSA.contains(algorithm) && jwk instanceof RSAKey jwkRsa) {
                final int bits = computeBitLengthRsa(jwkRsa.toPublicKey());
                final boolean isMatch = bits >= MIN_RSA_KEY_BITS;
                if (!isMatch) {
                    tracer.append(
                        "RSA JWK [" + bits + "] bits too small for algorithm [" + algorithm + "] minimum [" + MIN_RSA_KEY_BITS + "]."
                    );
                }
                return isMatch;
            }
            if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_EC.contains(algorithm) && jwk instanceof ECKey jwkEc) {
                final Curve curve = jwkEc.getCurve();
                final Set<Curve> allowed = Curve.forJWSAlgorithm(JWSAlgorithm.parse(algorithm));
                final boolean isMatch = allowed.contains(curve);
                if (!isMatch) {
                    tracer.append("EC JWK [" + curve + "] curve not allowed for algorithm [" + algorithm + "] allowed " + allowed + ".");
                }
                return isMatch;
            }
        } catch (Exception e) {
            final Supplier<String> message = () -> Strings.format(
                "Unexpected exception while matching JWK with kid [%s] to it's algorithm requirement.",
                jwk.getKeyID()
            );
            // Log exactly once: with the stack trace at TRACE, otherwise a one-line summary at DEBUG.
            if (logger.isTraceEnabled()) {
                logger.trace(message.get(), e);
            } else if (logger.isDebugEnabled()) {
                logger.debug(message.get());
            }
        }
        return false;
    }

    /**
     * Returns the size of an RSA public key's modulus in bits, rounded up to a whole number of bytes.
     *
     * @param publicKey the key to measure; must be an {@link RSAPublicKey}
     * @return the modulus bit length rounded up to a multiple of 8
     * @throws Exception if {@code publicKey} is {@code null} or not an {@link RSAPublicKey}
     */
    static int computeBitLengthRsa(final PublicKey publicKey) throws Exception {
        if (publicKey instanceof RSAPublicKey rsaPublicKey) {
            final int modulusBitLength = rsaPublicKey.getModulus().bitLength();
            // Round up to whole bytes so that, e.g., a 2047-bit modulus is counted as 2048 bits.
            return (modulusBitLength + Byte.SIZE - 1) / Byte.SIZE * Byte.SIZE;
        }
        final String actual = publicKey == null ? "null" : publicKey.getClass().getSimpleName();
        throw new Exception("Expected public key class [RSAPublicKey]. Got [" + actual + "] instead.");
    }

    /**
     * Parses JWKSet JSON content into its list of JWKs.
     *
     * @param jwkSetConfigKey the setting name, used only in the error message
     * @param jwkSetContents JWKSet JSON; may be {@code null} or blank
     * @return the parsed keys, or an empty list if {@code jwkSetContents} has no text
     * @throws SettingsException if the content is present but cannot be parsed
     */
    static List<JWK> loadJwksFromJwkSetString(final String jwkSetConfigKey, final CharSequence jwkSetContents) throws SettingsException {
        if (org.elasticsearch.common.Strings.hasText(jwkSetContents)) {
            try {
                return JWKSet.parse(jwkSetContents.toString()).getKeys();
            } catch (Exception e) {
                throw new SettingsException("JWKSet parse failed for setting [" + jwkSetConfigKey + "]", e);
            }
        }
        return Collections.emptyList();
    }

    /**
     * Builds an HMAC JWK from a configured secret string.
     *
     * @param jwkSetConfigKey the setting name, used only in the error message
     * @param hmacKeyContents the secret; may be {@code null} or blank
     * @return the HMAC key, or {@code null} if {@code hmacKeyContents} has no text
     * @throws SettingsException if building the key fails
     */
    static OctetSequenceKey loadHmacJwkFromJwkString(final String jwkSetConfigKey, final CharSequence hmacKeyContents) {
        if (org.elasticsearch.common.Strings.hasText(hmacKeyContents)) {
            try {
                return buildHmacKeyFromString(hmacKeyContents);
            } catch (Exception e) {
                throw new SettingsException("HMAC Key parse failed for setting [" + jwkSetConfigKey + "]", e);
            }
        }
        return null;
    }

    /**
     * Creates an {@link OctetSequenceKey} whose key bytes are the UTF-8 encoding of the given string.
     *
     * @param hmacKeyContents the secret text; must not be {@code null}
     * @return a new HMAC key
     */
    static OctetSequenceKey buildHmacKeyFromString(final CharSequence hmacKeyContents) {
        final byte[] utf8Bytes = hmacKeyContents.toString().getBytes(StandardCharsets.UTF_8);
        return new OctetSequenceKey.Builder(utf8Bytes).build();
    }
}

