/*
 * Decompiled with CFR 0.152.
 */
package io.helidon.security.jwt;

import io.helidon.common.Errors;
import io.helidon.security.jwt.Jwt;
import io.helidon.security.jwt.JwtException;
import io.helidon.security.jwt.JwtUtil;
import io.helidon.security.jwt.jwk.Jwk;
import io.helidon.security.jwt.jwk.JwkKeys;
import java.io.Reader;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.json.Json;
import javax.json.JsonObject;
import javax.json.JsonReader;

/**
 * A signed JWT in compact serialization ({@code header.payload.signature}).
 * <p>
 * Instances are created by signing a {@link Jwt} ({@link #sign(Jwt, JwkKeys)}, {@link #sign(Jwt, Jwk)})
 * or by parsing a token string ({@link #parseToken(String)}). Each instance holds:
 * <ul>
 *     <li>the header and payload JSON;</li>
 *     <li>the exact bytes that are signed ({@code base64url(header) + '.' + base64url(payload)});</li>
 *     <li>the raw signature bytes.</li>
 * </ul>
 */
public class SignedJwt {
    /**
     * Compact serialization pattern: {@code header.payload.signature}.
     * <p>
     * All three parts are base64url encoded, so {@code -} and {@code _} must be accepted in every part.
     * {@code +}, {@code /} and {@code =} are tolerated here and left to {@link #URL_DECODER} to accept
     * or reject. The signature part may be empty.
     */
    private static final Pattern JWT_PATTERN = Pattern.compile(
            "([a-zA-Z0-9_\\-/=+]+)\\.([a-zA-Z0-9_\\-/=+]+)\\.([a-zA-Z0-9_\\-/=+]*)");
    // The URL decoder accepts input both with and without trailing '=' padding.
    private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder();
    // JWS (RFC 7515, section 2) requires base64url encoding with all trailing '=' characters omitted.
    private static final Base64.Encoder URL_ENCODER = Base64.getUrlEncoder().withoutPadding();

    private final String tokenContent;
    private final JsonObject headerJson;
    private final JsonObject payloadJson;
    private final byte[] signedBytes;
    private final byte[] signature;

    private SignedJwt(String tokenContent, JsonObject headerJson, JsonObject payloadJson, byte[] signedBytes, byte[] signature) {
        this.tokenContent = tokenContent;
        this.headerJson = headerJson;
        this.payloadJson = payloadJson;
        this.signedBytes = signedBytes;
        this.signature = signature;
    }

    /**
     * Signs a JWT, selecting the key from a key set.
     * <ul>
     *     <li>If the JWT defines an algorithm, its key id selects the key. Without a key id, only
     *     the {@code none} algorithm is allowed.</li>
     *     <li>If there is no algorithm but there is a key id, the matching key is used.</li>
     *     <li>If there is neither, the token is signed with {@link Jwk#NONE_JWK}.</li>
     * </ul>
     *
     * @param jwt  the JWT to sign
     * @param jwks keys to select the signing key from
     * @return the signed JWT
     * @throws JwtException if a referenced key id is not found, or an algorithm other than
     *                      {@code none} is defined without a key id
     */
    public static SignedJwt sign(Jwt jwt, JwkKeys jwks) throws JwtException {
        return jwt.getAlgorithm()
                // explicit algorithm: delegate the key selection rules
                .map(alg -> SignedJwt.sign(jwt, jwks, alg))
                .orElseGet(() -> jwt.getKeyId()
                        // no algorithm, but a key id: sign with the key it references
                        .map(kid -> jwks.forKeyId((String) kid)
                                .map(jwk -> SignedJwt.sign(jwt, jwk))
                                .orElseThrow(() -> new JwtException("Could not find JWK based on key id. JWT: "
                                                                            + jwt + ", kid: " + kid)))
                        // neither algorithm nor key id
                        .orElseGet(() -> SignedJwt.sign(jwt, Jwk.NONE_JWK)));
    }

    /**
     * Signs a JWT with the given key.
     * <p>
     * The signed content is {@code base64url(header JSON) + '.' + base64url(payload JSON)} (UTF-8).
     * The resulting token is that content followed by {@code '.'} and the base64url-encoded signature.
     *
     * @param jwt the JWT to sign
     * @param jwk the key to sign with
     * @return the signed JWT
     * @throws JwtException as declared; signing itself is delegated to {@link Jwk#sign(byte[])}
     */
    public static SignedJwt sign(Jwt jwt, Jwk jwk) throws JwtException {
        JsonObject headerJson = jwt.getHeaderJson();
        JsonObject payloadJson = jwt.getPayloadJson();
        String headerJsonString = headerJson.toString();
        String payloadJsonString = payloadJson.toString();
        String headerBase64 = SignedJwt.encode(headerJsonString);
        String payloadBase64 = SignedJwt.encode(payloadJsonString);
        String signedString = headerBase64 + '.' + payloadBase64;
        byte[] signedBytes = signedString.getBytes(StandardCharsets.UTF_8);
        byte[] signature = jwk.sign(signedBytes);
        String signatureBase64 = SignedJwt.encode(signature);
        String tokenContent = signedString + '.' + signatureBase64;
        return new SignedJwt(tokenContent, headerJson, payloadJson, signedBytes, signature);
    }

    /**
     * Signs a JWT that explicitly defines an algorithm.
     * <p>
     * The key is chosen by the JWT's key id. Without a key id, only {@code none} is accepted,
     * and it maps to {@link Jwk#NONE_JWK}.
     */
    private static SignedJwt sign(Jwt jwt, JwkKeys jwks, String alg) {
        Jwk jwk = jwt.getKeyId()
                .map(kid -> jwks.forKeyId((String) kid)
                        .orElseThrow(() -> new JwtException("Could not find JWK for kid: " + kid)))
                .orElseGet(() -> {
                    if ("none".equals(alg)) {
                        return Jwk.NONE_JWK;
                    }
                    throw new JwtException("JWT defined with signature algorithm " + alg + ", yet no key id (kid): " + jwt);
                });
        return SignedJwt.sign(jwt, jwk);
    }

    /**
     * Parses a token in compact serialization. The signature is NOT verified here; use
     * {@link #verifySignature(JwkKeys)} for that.
     * <p>
     * Base64 decoding errors and JSON parsing errors are collected and reported through
     * {@code Errors#checkValid()}. Decoding errors are reported before any JSON parsing is attempted.
     *
     * @param tokenContent the token string
     * @return the parsed token
     * @throws JwtException if the string does not have the {@code header.payload.signature} shape
     */
    public static SignedJwt parseToken(String tokenContent) {
        Errors.Collector collector = Errors.collector();
        Matcher matcher = JWT_PATTERN.matcher(tokenContent);
        if (matcher.matches()) {
            String headerBase64 = matcher.group(1);
            String payloadBase64 = matcher.group(2);
            String signatureBase64 = matcher.group(3);
            String headerJsonString = SignedJwt.decode(headerBase64, collector, "JWT header");
            String payloadJsonString = SignedJwt.decode(payloadBase64, collector, "JWT payload");
            byte[] signatureBytes = SignedJwt.decodeBytes(signatureBase64, collector, "JWT signature");
            // fail before JSON parsing so parseJson never receives a null string
            collector.collect().checkValid();
            // signature is computed over the encoded parts exactly as they appear in the token
            String signedContent = headerBase64 + '.' + payloadBase64;
            JsonObject headerJson = SignedJwt.parseJson(headerJsonString, collector, "JWT header");
            JsonObject contentJson = SignedJwt.parseJson(payloadJsonString, collector, "JWT payload");
            collector.collect().checkValid();
            return new SignedJwt(tokenContent, headerJson, contentJson, signedContent.getBytes(StandardCharsets.UTF_8), signatureBytes);
        }
        throw new JwtException("Not a JWT token: " + tokenContent);
    }

    /**
     * Parses a JSON object.
     * <p>
     * On failure, a fatal error is collected and {@code null} is returned. The reader is always closed.
     */
    private static JsonObject parseJson(String jsonString, Errors.Collector collector, String description) {
        Reader source = new StringReader(jsonString);
        try (JsonReader reader = Json.createReader(source)) {
            return reader.readObject();
        } catch (Exception e) {
            collector.fatal((Object) jsonString, description + " is not a valid JSON object");
            return null;
        }
    }

    /** Encodes a string (as UTF-8) to unpadded base64url. */
    private static String encode(String string) {
        return SignedJwt.encode(string.getBytes(StandardCharsets.UTF_8));
    }

    /** Encodes bytes to unpadded base64url. */
    private static String encode(byte[] bytes) {
        return URL_ENCODER.encodeToString(bytes);
    }

    /**
     * Decodes a base64url string to a UTF-8 string.
     * <p>
     * On failure, a fatal error is collected and {@code null} is returned.
     */
    private static String decode(String base64, Errors.Collector collector, String description) {
        byte[] bytes = SignedJwt.decodeBytes(base64, collector, description);
        return (null == bytes) ? null : new String(bytes, StandardCharsets.UTF_8);
    }

    /**
     * Decodes a base64url string to bytes.
     * <p>
     * On failure, a fatal error is collected and {@code null} is returned.
     */
    private static byte[] decodeBytes(String base64, Errors.Collector collector, String description) {
        try {
            return URL_DECODER.decode(base64);
        } catch (Exception e) {
            collector.fatal((Object) base64, description + " is not a base64 encoded string.");
            return null;
        }
    }

    /**
     * Returns the token in compact serialization.
     *
     * @return the token string
     */
    public String getTokenContent() {
        return this.tokenContent;
    }

    JsonObject getHeaderJson() {
        return this.headerJson;
    }

    JsonObject getPayloadJson() {
        return this.payloadJson;
    }

    /**
     * Returns the bytes covered by the signature ({@code base64url(header) + '.' + base64url(payload)}).
     *
     * @return a defensive copy of the signed bytes
     */
    public byte[] getSignedBytes() {
        return Arrays.copyOf(this.signedBytes, this.signedBytes.length);
    }

    /**
     * Returns the raw (decoded) signature bytes.
     *
     * @return a defensive copy of the signature
     */
    public byte[] getSignature() {
        return Arrays.copyOf(this.signature, this.signature.length);
    }

    /**
     * Creates a {@link Jwt} from this token's header and payload.
     *
     * @return a new JWT instance
     */
    public Jwt getJwt() {
        return new Jwt(this.headerJson, this.payloadJson);
    }

    /**
     * Verifies the signature of this token. The key and algorithm are chosen from the
     * {@code alg} and {@code kid} header values:
     * <ul>
     *     <li>Neither present: a warning is collected and {@link Jwk#NONE_JWK} is used.</li>
     *     <li>Only {@code kid}: the key is looked up and its algorithm is used.</li>
     *     <li>Only {@code alg}: allowed only for {@code none}, which uses {@link Jwk#NONE_JWK}.
     *     Any other algorithm is a fatal error.</li>
     *     <li>Both: the key is looked up, and its algorithm must equal {@code alg}.</li>
     * </ul>
     * NOTE(review): in the {@link Jwk#NONE_JWK} cases, the outcome depends on how that key verifies
     * signatures. Callers that must reject unsigned tokens should check the algorithm themselves.
     * Confirm against {@code Jwk}.
     *
     * @param keys keys to look up the {@code kid} in
     * @return errors collected during verification. A fatal error is collected when the key is not
     *         found, the algorithms differ, or the signature does not verify.
     */
    public Errors verifySignature(JwkKeys keys) {
        Jwk jwk;
        Errors.Collector collector = Errors.collector();
        String alg = JwtUtil.getString(this.headerJson, "alg").orElse(null);
        String kid = JwtUtil.getString(this.headerJson, "kid").orElse(null);
        if (null == alg) {
            if (null == kid) {
                collector.warn("Neither alg nor kid are specified in JWT, assuming none algorithm");
                jwk = Jwk.NONE_JWK;
                alg = jwk.getAlgorithm();
            } else {
                jwk = keys.forKeyId(kid).orElse(null);
                if (null == jwk) {
                    collector.fatal((Object) keys, "Key for key id: " + kid + " not found");
                } else {
                    // no alg in header: trust the key's algorithm
                    alg = jwk.getAlgorithm();
                }
            }
        } else if (null == kid) {
            if ("none".equals(alg)) {
                jwk = Jwk.NONE_JWK;
            } else {
                collector.fatal("Algorithm is " + alg + ", yet no kid is defined in JWT header, cannot validate");
                jwk = null;
            }
        } else {
            jwk = keys.forKeyId(kid).orElse(null);
            if (null == jwk) {
                collector.fatal((Object) keys, "Key for key id: " + kid + " not found");
            }
        }
        if (null == jwk) {
            return collector.collect();
        }
        // reject tokens whose declared algorithm differs from the key's, before verifying
        if (jwk.getAlgorithm().equals(alg)) {
            if (!jwk.verifySignature(this.signedBytes, this.signature)) {
                collector.fatal((Object) jwk, "Signature of JWT token is not valid, based on alg: " + alg + ", kid: " + kid);
            }
        } else {
            collector.fatal((Object) jwk, "Algorithm of JWK (" + jwk.getAlgorithm() + ") does not match algorithm of this JWT (" + alg + ") for kid: " + kid);
        }
        return collector.collect();
    }
}

