/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.userinfo.profile.impl;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collection;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialResolver;
import org.opensaml.security.credential.UsageType;
import org.opensaml.security.criteria.UsageCriterion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.token.AccessToken; 

import net.shibboleth.idp.plugin.oidc.op.messaging.context.OIDCAuthenticationResponseContext;
import net.shibboleth.idp.plugin.oidc.op.token.support.AccessTokenClaimsSet;
import net.shibboleth.oidc.profile.core.OidcEventIds;
import net.shibboleth.oidc.security.impl.JWTSignatureValidationUtil;
import net.shibboleth.utilities.java.support.annotation.constraint.NonnullAfterInit;
import net.shibboleth.utilities.java.support.annotation.constraint.NotEmpty;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.resolver.ResolverException;
import net.shibboleth.utilities.java.support.security.DataSealer;
import net.shibboleth.utilities.java.support.security.DataSealerException;

/**
 * Action that parses an access token and initially populates the claims for later
 * validation.
 * 
 * <p>The token is first parsed as a signed JWT. If that fails, it is parsed as an opaque
 * (data-sealer wrapped) token. Signed JWTs are also checked here: the "typ" header must
 * be "at+jwt" or "application/at+jwt" (case-insensitive), and the signature is verified
 * against the signing credentials from the configured {@link CredentialResolver}.</p>
 * 
 * <p>The parsed token is stored in the {@link OIDCAuthenticationResponseContext} as its
 * authorization grant claims set (via {@code setAuthorizationGrantClaimsSet}). Claims
 * validation takes place later, so that metadata and relying-party/profile config lookup
 * can supply pluggable validation, an overridden OP/issuer name, etc.</p>
 * 
 * @event {@link EventIds#PROCEED_EVENT_ID}
 * @event {@link OidcEventIds#INVALID_GRANT}
 * 
 * @since 3.2.0
 */
public class ParseAccessToken extends AbstractOIDCUserInfoValidationResponseAction {

    /** Token "typ" header value for JWT access tokens (RFC 9068). */
    @Nonnull @NotEmpty private static final String AT_JWT_TYPE = "at+jwt";

    /** Media-type form of {@link #AT_JWT_TYPE}, also permitted by RFC 9068. */
    @Nonnull @NotEmpty private static final String AT_JWT_MEDIA_TYPE = "application/" + AT_JWT_TYPE;

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(ParseAccessToken.class);

    /** Data sealer for unwrapping the access token (used for both JWT and opaque forms). */
    @NonnullAfterInit private DataSealer dataSealer;
    
    /** Source of signing keys. */
    @Nullable private CredentialResolver credentialResolver;
    
    /**
     * Signed JWT parsed from the current request's token, or null if the current token is
     * not a JWT. Reset on every call to {@link #parseAccessToken(AccessToken)} so that state
     * from a previous execution is never checked against the current token.
     */
    @Nullable private SignedJWT signedJWT;
    
    /**
     * Set the data sealer instance to use.
     * 
     * @param sealer sealer to use
     */
    public void setDataSealer(@Nonnull final DataSealer sealer) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        dataSealer = Constraint.isNotNull(sealer, "DataSealer cannot be null");
    }
    
    /**
     * Set the source of signing keys to use for JWT signature verification.
     * 
     * <p>If no resolver is set, any JWT access token is rejected, because its signature
     * cannot be verified.</p>
     * 
     * @param resolver signing key resolver
     */
    public void setCredentialResolver(@Nullable final CredentialResolver resolver) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        credentialResolver = resolver;
    }
    
    /** {@inheritDoc} */
    @Override
    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        
        if (dataSealer == null) {
            throw new ComponentInitializationException("DataSealer cannot be null");
        }
    }

 // Checkstyle: CyclomaticComplexity OFF

    /** {@inheritDoc} */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {

        final AccessToken token = getUserInfoRequest().getAccessToken();
        if (token == null) {
            log.error("{} Token missing from request", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }
        
        // Also sets (or clears) signedJWT for this token.
        final AccessTokenClaimsSet accessTokenClaimsSet = parseAccessToken(token);
        if (accessTokenClaimsSet == null) {
            log.warn("{} Unable to parse/decode token for validation", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }
        
        log.debug("{} Access token unwrapped: {}", getLogPrefix(), accessTokenClaimsSet.serialize());
        
        if (signedJWT != null) {
            // Check typ header.
            final JOSEObjectType typ = signedJWT.getHeader().getType();
            if (!isAccessTokenType(typ)) {
                log.warn("{} Missing or invalid token type: {}", getLogPrefix(), typ != null ? typ.getType() : "null");
                ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
                return;
            }
            
            if (credentialResolver == null) {
                log.error("{} No CredentialResolver available, can't verify JWT signature", getLogPrefix());
                ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
                return;
            }
            
            log.debug("{} Checking JWT signature", getLogPrefix());
            final Collection<Credential> credList = new ArrayList<>();
            final CriteriaSet criteriaSet = new CriteriaSet(new UsageCriterion(UsageType.SIGNING));
            try {
                final Iterable<Credential> creds = credentialResolver.resolve(criteriaSet);
                if (creds != null) {
                    creds.forEach(credList::add);
                }
            } catch (final ResolverException e) {
                log.error("{} Failure resolving signing credentials, can't verify JWT signature", getLogPrefix(), e);
                ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
                return;
            }
            // A non-null result is the event ID to signal. INVALID_GRANT is passed as the
            // error event to use (presumed from the argument; confirm in JWTSignatureValidationUtil).
            final String errorEventId = JWTSignatureValidationUtil.validateSignatureEx(credList, signedJWT,
                    OidcEventIds.INVALID_GRANT);
            if (errorEventId != null) {
                log.warn("{} Signature on token ID '{}' invalid", getLogPrefix(), accessTokenClaimsSet.getID());
                ActionSupport.buildEvent(profileRequestContext, errorEventId);
                return;
            }
        }

        log.debug("{} Access token {} parsed", getLogPrefix(), accessTokenClaimsSet.getID());
        getOidcResponseContext().setAuthorizationGrantClaimsSet(accessTokenClaimsSet);
    }

 // Checkstyle: CyclomaticComplexity ON

    /**
     * Check whether a JOSE "typ" header identifies a JWT access token.
     * 
     * <p>RFC 9068 allows both "at+jwt" and "application/at+jwt". RFC 7515 makes the
     * comparison case-insensitive.</p>
     * 
     * @param typ the header value, possibly null
     * 
     * @return true iff the value is an acceptable access token type
     */
    private static boolean isAccessTokenType(@Nullable final JOSEObjectType typ) {
        if (typ == null) {
            return false;
        }
        final String value = typ.getType();
        return AT_JWT_TYPE.equalsIgnoreCase(value) || AT_JWT_MEDIA_TYPE.equalsIgnoreCase(value);
    }

    /**
     * Attempt to parse token, first as a signed JWT and then as an opaque sealed value.
     * 
     * <p>As a side effect, this records the parsed {@link SignedJWT} when the token value is
     * a syntactically valid JWT, and clears any previously recorded one otherwise.</p>
     * 
     * @param token the token
     * 
     * @return parsed claim set or null
     */
    @Nullable protected AccessTokenClaimsSet parseAccessToken(@Nonnull @NotEmpty final AccessToken token) {
        
        // Never carry a JWT over from an earlier execution. Otherwise an opaque token would be
        // header/signature-checked against some other request's JWT.
        signedJWT = null;
        
        // Try parsing as a JWT.
        try {
            signedJWT = SignedJWT.parse(token.getValue());
            return AccessTokenClaimsSet.parse(signedJWT, dataSealer);
        } catch (final DataSealerException | ParseException e) {
            // Expected for opaque tokens; fall through to the opaque format.
            log.debug("{} Token not parseable as a JWT access token, trying opaque format: {}",
                    getLogPrefix(), e.getMessage());
        }

        // Fall back to opaque.
        try {
            return AccessTokenClaimsSet.parse(token.getValue(), dataSealer);
        } catch (final DataSealerException | ParseException e) {
            log.debug("{} Token not parseable as an opaque access token: {}", getLogPrefix(), e.getMessage());
        }
        
        return null;
    }

}