/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.profile.impl;

import java.security.PrivateKey;
import java.security.interfaces.ECPrivateKey;
import java.text.ParseException;
import java.util.Iterator;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.crypto.SecretKey;

import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.saml.saml2.profile.context.EncryptionContext;
import org.opensaml.security.credential.Credential;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEDecrypter;
import com.nimbusds.jose.crypto.AESDecrypter;
import com.nimbusds.jose.crypto.ECDHDecrypter;
import com.nimbusds.jose.crypto.RSADecrypter;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTParser;

import net.shibboleth.idp.profile.context.RelyingPartyContext;
import net.shibboleth.oidc.profile.core.OidcEventIds;
import net.shibboleth.oidc.security.impl.OIDCDecryptionParameters;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;

/**
 * Action that decrypts the request object if it is encrypted. The decrypted object is written back to the
 * OIDC response context, replacing the encrypted one.
 *
 * <p>
 * The action does nothing when there is no request object or when the request object is not an
 * {@link EncryptedJWT}. Decryption parameters are read from the {@link EncryptionContext} located by the configured
 * lookup strategy and must be an instance of {@link OIDCDecryptionParameters}.
 * </p>
 */
public class DecryptRequestObject extends AbstractOIDCAuthenticationResponseAction {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(DecryptRequestObject.class);

    /** Strategy used to look up the {@link EncryptionContext} to read the decryption parameters from. */
    @Nonnull private Function<ProfileRequestContext, EncryptionContext> encryptionContextLookupStrategy;

    /** Decryption parameters for decrypting payload. */
    @Nullable private OIDCDecryptionParameters params;

    /** Request Object. */
    @Nullable private JWT requestObject;

    /**
     * Constructor. Defaults the lookup strategy to the {@link EncryptionContext} child of the
     * {@link RelyingPartyContext} child of the profile request context.
     */
    public DecryptRequestObject() {
        encryptionContextLookupStrategy = new ChildContextLookup<>(EncryptionContext.class).compose(
                new ChildContextLookup<>(RelyingPartyContext.class));
    }

    /**
     * Set the strategy used to look up the {@link EncryptionContext} holding the decryption parameters.
     * 
     * @param strategy lookup strategy, must not be null
     */
    public void setEncryptionContextLookupStrategy(
            @Nonnull final Function<ProfileRequestContext, EncryptionContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        encryptionContextLookupStrategy =
                Constraint.isNotNull(strategy, "EncryptionContext lookup strategy cannot be null");
    }

    /**
     * {@inheritDoc}
     * 
     * <p>
     * Returns false when there is no request object, when the request object is not encrypted, or when no
     * {@link OIDCDecryptionParameters} are available. Only the last case signals
     * {@link EventIds#INVALID_SEC_CFG}.
     * </p>
     */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }
        
        requestObject = getOidcResponseContext().getRequestObject();
        if (requestObject == null) {
            log.debug("{} No request object, nothing to do", getLogPrefix());
            return false;
        }
        if (!(requestObject instanceof EncryptedJWT)) {
            log.debug("{} Request object not encrypted, nothing to do", getLogPrefix());
            return false;
        }
        // OIDC decryption parameters are set to stock shibboleth context as
        // EncryptionContext#getAttributeEncryptionParameters()
        final EncryptionContext encryptCtx = encryptionContextLookupStrategy.apply(profileRequestContext);
        if (encryptCtx == null
                || !(encryptCtx.getAttributeEncryptionParameters() instanceof OIDCDecryptionParameters)) {
            log.error(
                    "{} Encrypted request object but no EncryptionContext/OIDCDecryptionParameters available",
                    getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, EventIds.INVALID_SEC_CFG);
            return false;
        }
        params = (OIDCDecryptionParameters) encryptCtx.getAttributeEncryptionParameters();
        return true;
    }

    /**
     * Build a decrypter suitable for the given key management algorithm from the given credential.
     * 
     * @param alg key management algorithm taken from the JWE header
     * @param credential credential to take the key material from
     * @return decrypter, or null if the algorithm belongs to none of the supported families
     * @throws JOSEException if the credential does not carry key material usable with the algorithm, or if the
     *             decrypter rejects the key
     */
    @Nullable private JWEDecrypter buildDecrypter(@Nonnull final JWEAlgorithm alg,
            @Nonnull final Credential credential) throws JOSEException {
        if (JWEAlgorithm.Family.RSA.contains(alg)) {
            final PrivateKey key = credential.getPrivateKey();
            // Checked here so that an unsuitable credential is skipped instead of raising an unchecked exception.
            if (key == null || !"RSA".equalsIgnoreCase(key.getAlgorithm())) {
                throw new JOSEException("credential does not contain an RSA private key");
            }
            return new RSADecrypter(key);
        }
        if (JWEAlgorithm.Family.ECDH_ES.contains(alg)) {
            final PrivateKey key = credential.getPrivateKey();
            // instanceof also covers null and avoids a ClassCastException for non-EC credentials.
            if (!(key instanceof ECPrivateKey)) {
                throw new JOSEException("credential does not contain an EC private key");
            }
            return new ECDHDecrypter((ECPrivateKey) key);
        }
        if (JWEAlgorithm.Family.AES_GCM_KW.contains(alg) || JWEAlgorithm.Family.AES_KW.contains(alg)) {
            final SecretKey key = credential.getSecretKey();
            if (key == null) {
                throw new JOSEException("credential does not contain a secret key");
            }
            return new AESDecrypter(key);
        }
        return null;
    }

    // Checkstyle: CyclomaticComplexity OFF

    /**
     * Decrypt request object.
     * 
     * <p>
     * The header's alg and enc values must match the configured decryption parameters. Each configured credential
     * is then tried in turn until one decrypts the object. A payload that decrypts but cannot be parsed as a JWT
     * is a terminal failure: no further credentials are tried.
     * </p>
     * 
     * @param encryptedObject request object to decrypt.
     * @return Decrypted request object. Null if decrypting failed.
     */
    @Nullable private JWT decryptRequestObject(@Nonnull final EncryptedJWT encryptedObject) {
        final JWEAlgorithm encAlg = encryptedObject.getHeader().getAlgorithm();
        if (!encAlg.getName().equals(params.getKeyTransportEncryptionAlgorithm())) {
            log.error("{} Request object alg {} not matching expected {}", getLogPrefix(),
                    encAlg.getName(), params.getKeyTransportEncryptionAlgorithm());
            return null;
        }
        final String encMethod = encryptedObject.getHeader().getEncryptionMethod().getName();
        if (!encMethod.equals(params.getDataEncryptionAlgorithm())) {
            log.error("{} Request object enc {} not matching expected {}", getLogPrefix(),
                    encMethod, params.getDataEncryptionAlgorithm());
            return null;
        }
        final Iterator<Credential> it = params.getKeyTransportDecryptionCredentials().iterator();
        while (it.hasNext()) {
            final Credential credential = it.next();
            try {
                final JWEDecrypter decrypter = buildDecrypter(encAlg, credential);
                if (decrypter == null) {
                    // The decrypter is selected by the key management algorithm (alg), so report that value.
                    log.error("{} No decrypter for request object for encAlg {}", getLogPrefix(),
                            encAlg.getName());
                    return null;
                }
                encryptedObject.decrypt(decrypter);
            } catch (final JOSEException e) {
                if (it.hasNext()) {
                    log.debug("{} Unable to decrypt request object with credential, {}, picking next key",
                            getLogPrefix(), e.getMessage());
                    continue;
                }
                log.error("{} Unable to decrypt request object with credential, {}", getLogPrefix(),
                        e.getMessage());
                return null;
            }
            // Decryption succeeded. Parsing is handled separately because retrying with another credential
            // cannot help and would call decrypt() again on an object that is already decrypted.
            try {
                return JWTParser.parse(encryptedObject.getPayload().toString());
            } catch (final ParseException e) {
                log.error("{} Unable to parse decrypted request object payload, {}", getLogPrefix(),
                        e.getMessage());
                return null;
            }
        }
        // Reached only when there are no decryption credentials to iterate over.
        log.error("{} No decryption credentials available for request object", getLogPrefix());
        return null;
    }
    
    // Checkstyle: CyclomaticComplexity ON

    /**
     * {@inheritDoc}
     * 
     * <p>
     * Decrypts the request object and stores the result in the OIDC response context. Signals
     * {@link OidcEventIds#INVALID_REQUEST_OBJECT} if decryption fails.
     * </p>
     */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {

        requestObject = decryptRequestObject((EncryptedJWT) requestObject);
        if (requestObject == null) {
            log.error("{} Unable to decrypt request object", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_REQUEST_OBJECT);
            return;
        }
        
        // Let's update decrypted request object back to response context
        getOidcResponseContext().setRequestObject(requestObject);
        log.debug("{} Request object decrypted as {}", getLogPrefix(),
                getOidcResponseContext().getRequestObject().serialize());
    }

}