package com.atlassian.asap.nimbus.serializer;

import java.security.PrivateKey;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.RSAPrivateKey;

import com.atlassian.asap.api.AlgorithmType;
import com.atlassian.asap.api.Jwt;
import com.atlassian.asap.api.JwtClaims;
import com.atlassian.asap.api.JwtClaims.Claim;
import com.atlassian.asap.api.SigningAlgorithm;
import com.atlassian.asap.core.exception.SigningException;
import com.atlassian.asap.core.exception.UnsupportedAlgorithmException;
import com.atlassian.asap.core.serializer.JwtSerializer;

import com.google.common.annotations.VisibleForTesting;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.ECDSASigner;
import com.nimbusds.jose.crypto.RSASSASigner;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.minidev.json.JSONObject;

/**
 * A {@link JwtSerializer} implemented on top of the Nimbus JOSE library.
 *
 * <p>The JWT header and claims are converted into a Nimbus {@link JWSObject}, signed with the supplied
 * private key and then rendered as a string by {@link JWSObject#serialize()}.
 */
public class NimbusJwtSerializer implements JwtSerializer
{
    private static final Logger logger = LoggerFactory.getLogger(NimbusJwtSerializer.class);

    /**
     * Signs the given JWT with the given private key and returns its serialized form.
     *
     * @param jwt the token whose header and claims are to be signed
     * @param privateKey the key to sign with; its type must match the algorithm declared in the JWT header
     * @return the string produced by {@link JWSObject#serialize()} for the signed token
     * @throws SigningException if the Nimbus library fails to sign the token
     * @throws UnsupportedAlgorithmException if the algorithm type and the key type are not a supported combination
     */
    @Override
    public String serialize(Jwt jwt, PrivateKey privateKey) throws SigningException, UnsupportedAlgorithmException
    {
        return getSignedJwsObject(jwt, privateKey).serialize();
    }

    /**
     * Builds the Nimbus JWS object for the given JWT and signs it.
     *
     * <p>A {@link SigningException} is thrown if Nimbus reports a {@link JOSEException} while signing; the
     * original failure is logged and attached to the thrown exception.
     *
     * @param jwt the token whose header and claims are to be signed
     * @param privateKey the key to sign with
     * @return the signed JWS object
     * @throws UnsupportedAlgorithmException if no signer can be created for the algorithm and key type
     */
    @VisibleForTesting
    JWSObject getSignedJwsObject(Jwt jwt, PrivateKey privateKey) throws UnsupportedAlgorithmException
    {
        SigningAlgorithm algorithm = jwt.getHeader().getAlgorithm();
        JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(algorithm.name()); // fails if algorithm is None
        JWSHeader header = new JWSHeader.Builder(jwsAlgorithm)
                .keyID(jwt.getHeader().getKeyId())
                .build();
        JWSObject jwsObject = new JWSObject(header, new Payload(toJsonPayload(jwt.getClaims())));
        try
        {
            jwsObject.sign(getSigner(algorithm, privateKey));
        }
        catch (JOSEException e)
        {
            logger.error("Unexpected error when signing JWT token", e);
            throw signingFailure(e);
        }
        return jwsObject;
    }

    /**
     * Creates the {@link SigningException} reported to callers, carrying the underlying Nimbus failure so
     * that it is not only visible in the log.
     *
     * @param cause the failure reported by the Nimbus library
     * @return a new signing exception with {@code cause} attached
     */
    private static SigningException signingFailure(JOSEException cause)
    {
        SigningException failure = new SigningException();
        try
        {
            failure.initCause(cause);
        }
        catch (IllegalStateException causeAlreadySet)
        {
            // initCause refuses to overwrite a cause fixed at construction time; keep the failure reachable anyway
            failure.addSuppressed(cause);
        }
        return failure;
    }

    /**
     * Selects a signer for the given algorithm and key.
     *
     * <p>RSA and RSASSA-PSS algorithm types require an {@link RSAPrivateKey}; the ECDSA algorithm type
     * requires an {@link ECPrivateKey}. Any other combination (including a {@code null} key, which matches
     * neither key type) is rejected.
     *
     * @param algorithm the signing algorithm declared in the JWT header
     * @param privateKey the key to sign with
     * @return a signer bound to {@code privateKey}
     * @throws UnsupportedAlgorithmException if the algorithm type and key type are not a supported combination
     */
    private JWSSigner getSigner(SigningAlgorithm algorithm, PrivateKey privateKey) throws UnsupportedAlgorithmException
    {
        boolean rsaFamily = algorithm.type() == AlgorithmType.RSA || algorithm.type() == AlgorithmType.RSASSA_PSS;
        if (rsaFamily && privateKey instanceof RSAPrivateKey)
        {
            return createRSASSASignerForKey((RSAPrivateKey) privateKey);
        }

        if (algorithm.type() == AlgorithmType.ECDSA && privateKey instanceof ECPrivateKey)
        {
            return createECDSASignerForKey((ECPrivateKey) privateKey);
        }

        throw new UnsupportedAlgorithmException(String.format("Unsupported algorithm %s or signing key type", algorithm.name()));
    }

    /**
     * Creates the signer used for RSA and RSASSA-PSS algorithm types. Overridable so tests can substitute a signer.
     *
     * @param privateKey the RSA private key to sign with
     * @return a new {@link RSASSASigner} for {@code privateKey}
     */
    @VisibleForTesting
    protected JWSSigner createRSASSASignerForKey(RSAPrivateKey privateKey)
    {
        return new RSASSASigner(privateKey);
    }

    /**
     * Creates the signer used for the ECDSA algorithm type. Overridable so tests can substitute a signer.
     *
     * @param privateKey the elliptic curve private key to sign with
     * @return a new {@link ECDSASigner} built from the private value {@code S} of {@code privateKey}
     */
    @VisibleForTesting
    protected JWSSigner createECDSASignerForKey(ECPrivateKey privateKey)
    {
        return new ECDSASigner(privateKey.getS());
    }

    /**
     * Converts the claims into the JSON object used as the JWS payload.
     *
     * <p>Issuer, audience, JWT id, issued-at and expiry are always written; subject and not-before are
     * written only when present. Timestamps are written as epoch seconds.
     *
     * @param claims the claims to convert
     * @return a JSON object keyed by the {@link Claim} keys
     */
    private static JSONObject toJsonPayload(JwtClaims claims)
    {
        JSONObject payload = new JSONObject();
        payload.put(Claim.ISSUER.key(), claims.getIssuer());

        if (claims.getSubject().isPresent())
        {
            payload.put(Claim.SUBJECT.key(), claims.getSubject().get());
        }

        // optimisation: if the audience is a singleton, send it as a single value, otherwise send it as multivalued
        Object audience = claims.getAudience().size() == 1
                ? claims.getAudience().iterator().next()
                : claims.getAudience();
        payload.put(Claim.AUDIENCE.key(), audience);

        payload.put(Claim.JWT_ID.key(), claims.getJwtId());
        payload.put(Claim.ISSUED_AT.key(), claims.getIssuedAt().getEpochSecond());
        payload.put(Claim.EXPIRY.key(), claims.getExpiry().getEpochSecond());

        if (claims.getNotBefore().isPresent())
        {
            payload.put(Claim.NOT_BEFORE.key(), claims.getNotBefore().get().getEpochSecond());
        }

        return payload;
    }
}
