/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.userinfo.profile.impl;

import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.shibboleth.idp.plugin.oidc.op.messaging.context.OIDCAuthenticationResponseContext;
import net.shibboleth.idp.plugin.oidc.op.token.support.AccessTokenClaimsSet;
import net.shibboleth.idp.plugin.oidc.op.token.support.TokenClaimsSet;
import net.shibboleth.idp.profile.IdPEventIds;
import net.shibboleth.oidc.jwt.claims.ClaimsValidator;
import net.shibboleth.oidc.jwt.claims.JWTValidationException;
import net.shibboleth.oidc.profile.config.navigate.IssuedClaimsValidatorLookupFunction;
import net.shibboleth.oidc.profile.core.OidcEventIds;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;

/**
 * Action that validates the claims pulled from an access token as usable for access
 * to the OP's UserInfo endpoint.
 * 
 * <p>The parsed claims are pulled from
 * {@link OIDCAuthenticationResponseContext#getAuthorizationGrantClaimsSet()}.</p>
 *
 * @event {@link EventIds#PROCEED_EVENT_ID}
 * @event {@link IdPEventIds#INVALID_PROFILE_CONFIG}
 * @event {@link OidcEventIds#INVALID_GRANT}
 */
public class ValidateAccessToken extends AbstractOIDCUserInfoValidationResponseAction {

    /** Class logger. */
    @Nonnull private Logger log = LoggerFactory.getLogger(ValidateAccessToken.class);
    
    /** Lookup strategy for claims validator. */
    @Nonnull private Function<ProfileRequestContext,ClaimsValidator> claimsValidatorLookupStrategy;
    
    /** The claims validator to use. */
    @Nullable private ClaimsValidator claimsValidator;

    /** Constructor. */
    public ValidateAccessToken() {
        claimsValidatorLookupStrategy = new IssuedClaimsValidatorLookupFunction();
    }

    /**
     * Set the claims validator lookup strategy.
     * 
     * @param strategy lookup strategy
     */
    public void setClaimsValidatorLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,ClaimsValidator> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        claimsValidatorLookupStrategy = Constraint.isNotNull(strategy, "Lookup strategy cannot be null");
    }
        
    /** {@inheritDoc} */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }
        
        claimsValidator = claimsValidatorLookupStrategy.apply(profileRequestContext);
        if (claimsValidator == null) {
            log.error("{} Unable to obtain ClaimsValidator to apply", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, IdPEventIds.INVALID_PROFILE_CONFIG);
            return false;
        }
        
        return true;
    }
    
    
    /** {@inheritDoc} */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        
        final TokenClaimsSet tokenClaims = getOidcResponseContext().getAuthorizationGrantClaimsSet();
        if (!(tokenClaims instanceof AccessTokenClaimsSet) || tokenClaims.getClaimsSet() == null) {
            log.error("{} Claims validation failed, unable to locate access token claims set to validate",
                    getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }

        log.debug("{} Validating parsed/decoded claims set: {}", getLogPrefix(), tokenClaims.getClaimsSet().toString());
        try {
            claimsValidator.validate(tokenClaims.getClaimsSet(), profileRequestContext);
        } catch (final JWTValidationException e) {
            log.warn("{} Claims validation failed, token is invalid: {}", getLogPrefix(), e.getMessage());
            ActionSupport.buildEvent(profileRequestContext, OidcEventIds.INVALID_GRANT);
            return;
        }

        log.debug("{} Access token {} validated", getLogPrefix(), tokenClaims.getID());
    }

}