/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.profile.impl;

import java.security.PrivateKey;
import java.security.interfaces.ECPrivateKey;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.crypto.SecretKey;

import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.security.credential.Credential;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.ECDSASigner;
import com.nimbusds.jose.crypto.MACSigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import net.shibboleth.oidc.security.credential.JWKCredential;
import net.shibboleth.oidc.security.impl.CredentialConversionUtil;
import net.shibboleth.utilities.java.support.annotation.constraint.NotEmpty;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.primitive.StringSupport;

/**
 * Abstract action for signing a JWT. The extending class supplies the claims set to sign by implementing
 * {@link #getClaimsSetToSign}, and receives the signed JWT through its implementation of
 * {@link #setSignedJWT}.
 *
 * <p>
 * The signing credential and algorithm are read from the signature signing parameters exposed by the parent
 * action. Any failure to build a signer or to sign results in an {@link EventIds#UNABLE_TO_SIGN} event and
 * {@link #setSignedJWT} is not called.
 * </p>
 */
public abstract class AbstractSignJWTAction extends AbstractOIDCSigningResponseAction {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(AbstractSignJWTAction.class);

    /** "typ" header to insert while signing. */
    @Nullable @NotEmpty private String typeHeader;

    /** Signing credential read from the signature signing parameters during pre-execute. */
    @Nullable private Credential credential;

    /**
     * Sets the value to be inserted as a "typ" header for the JWS.
     * 
     * <p>
     * The value is trimmed; a null or empty value results in no "typ" header being added.
     * </p>
     * 
     * @param type header value
     * 
     * @since 3.1.0
     */
    public void setTypeHeader(@Nullable @NotEmpty final String type) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        typeHeader = StringSupport.trimOrNull(type);
    }
    
    /** {@inheritDoc} */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {

        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }
        
        // May still be null; getSigner() reports that as a signing failure instead of failing here.
        credential = getSignatureSigningParameters().getSigningCredential();
        return true;
    }

    /**
     * Returns correct implementation of signer based on algorithm type.
     * 
     * <p>
     * The key material of the resolved credential is validated up front so that an unusable credential is
     * reported as a {@link JOSEException} (and thereby as a signing failure) rather than escaping as an
     * unchecked {@link ClassCastException} or {@link NullPointerException}.
     * </p>
     * 
     * @param jwsAlgorithm JWS algorithm
     * @return signer for algorithm and private key
     * @throws JOSEException if algorithm cannot be supported, or the credential lacks key material suitable
     *             for the algorithm
     */
    private JWSSigner getSigner(final Algorithm jwsAlgorithm) throws JOSEException {
        if (credential == null) {
            throw new JOSEException("No signing credential available");
        }
        if (JWSAlgorithm.Family.EC.contains(jwsAlgorithm)) {
            final PrivateKey privateKey = credential.getPrivateKey();
            // instanceof is false for null as well, so this covers both a missing and a non-EC key.
            if (!(privateKey instanceof ECPrivateKey)) {
                throw new JOSEException(
                        "Algorithm " + jwsAlgorithm.getName() + " requires an EC private key in the credential");
            }
            return new ECDSASigner((ECPrivateKey) privateKey);
        }
        if (JWSAlgorithm.Family.RSA.contains(jwsAlgorithm)) {
            final PrivateKey privateKey = credential.getPrivateKey();
            if (privateKey == null) {
                throw new JOSEException(
                        "Algorithm " + jwsAlgorithm.getName() + " requires a private key in the credential");
            }
            try {
                return new RSASSASigner(privateKey);
            } catch (final IllegalArgumentException e) {
                // The signer constructor rejects unsuitable keys with an unchecked exception; convert it so
                // that the caller's JOSEException handling applies.
                throw new JOSEException("Private key not usable for algorithm " + jwsAlgorithm.getName() + ": "
                        + e.getMessage(), e);
            }
        }
        if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) {
            final SecretKey secretKey = credential.getSecretKey();
            if (secretKey == null) {
                throw new JOSEException(
                        "Algorithm " + jwsAlgorithm.getName() + " requires a secret key in the credential");
            }
            return new MACSigner(secretKey);
        }
        throw new JOSEException("Unsupported algorithm " + jwsAlgorithm.getName());
    }

    /**
     * Resolves JWS algorithm from signature signing parameters.
     * 
     * <p>
     * If the credential is a {@link JWKCredential} whose algorithm differs from the configured signature
     * algorithm, the mismatch is only logged; the configured signature algorithm is still the one returned.
     * </p>
     * 
     * @return JWS algorithm
     */
    protected JWSAlgorithm resolveAlgorithm() {

        final JWSAlgorithm algorithm = new JWSAlgorithm(getSignatureSigningParameters().getSignatureAlgorithm());
        if (credential instanceof JWKCredential) {
            final JWKCredential jwkCredential = (JWKCredential) credential;
            if (!algorithm.equals(jwkCredential.getAlgorithm())) {
                log.debug("{} Signature signing algorithm {} differs from JWK algorithm {}", getLogPrefix(),
                        algorithm.getName(), jwkCredential.getAlgorithm());
            }
        }
        log.debug("{} Algorithm resolved {}", getLogPrefix(), algorithm.getName());
        return algorithm;
    }

    /**
     * Called with signed JWT as parameter.
     * 
     * @param jwt signed JWT.
     */
    protected abstract void setSignedJWT(@Nullable SignedJWT jwt);

    /**
     * Called to get claim set to sign.
     * 
     * @return claim set to sign
     */
    protected abstract @Nonnull JWTClaimsSet getClaimsSetToSign();

    /**
     * Signs the claims set supplied by {@link #getClaimsSetToSign()} and hands the result to
     * {@link #setSignedJWT(SignedJWT)}. On a signing failure an {@link EventIds#UNABLE_TO_SIGN} event is
     * built instead.
     * 
     * {@inheritDoc}
     */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {

        final JWTClaimsSet jwtClaimSet = getClaimsSetToSign();
        // Defensive: the contract is @Nonnull, but an implementation returning null means nothing to sign.
        if (jwtClaimSet == null) {
            log.debug("{} Claim set is null, nothing to do", getLogPrefix());
            return;
        }

        final SignedJWT jwt;
        try {
            final JWSAlgorithm jwsAlgorithm = resolveAlgorithm();
            // Build the signer first: it validates the credential before the credential is used for the kid.
            final JWSSigner signer = getSigner(jwsAlgorithm);
            final JWSHeader.Builder headerBuilder =
                    new JWSHeader.Builder(jwsAlgorithm).keyID(CredentialConversionUtil.resolveKid(credential));
            if (typeHeader != null) {
                headerBuilder.type(new JOSEObjectType(typeHeader));
            }
            jwt = new SignedJWT(headerBuilder.build(), jwtClaimSet);
            jwt.sign(signer);
        } catch (final JOSEException e) {
            // Trailing throwable keeps the cause and stack trace in the log.
            log.error("{} Error signing claim set: {}", getLogPrefix(), e.getMessage(), e);
            ActionSupport.buildEvent(profileRequestContext, EventIds.UNABLE_TO_SIGN);
            return;
        }
        setSignedJWT(jwt);
    }

}