/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.authn.impl;

import java.text.ParseException;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.security.auth.Subject;
import javax.security.auth.login.LoginException;

import net.shibboleth.idp.authn.AbstractCredentialValidator;
import net.shibboleth.idp.authn.AuthnEventIds;
import net.shibboleth.idp.authn.context.AuthenticationContext;
import net.shibboleth.idp.authn.principal.UsernamePrincipal;
import net.shibboleth.oidc.authn.context.OAuth2ClientAuthenticationContext;
import net.shibboleth.idp.profile.context.RelyingPartyContext;
import net.shibboleth.oidc.jwt.claims.ClaimsValidator;
import net.shibboleth.oidc.jwt.claims.JWTValidationException;
import net.shibboleth.oidc.profile.config.navigate.ClaimsValidatorLookupFunction;
import net.shibboleth.oidc.security.impl.JWTSignatureValidationUtil;
import net.shibboleth.utilities.java.support.annotation.constraint.NotEmpty;
import net.shibboleth.utilities.java.support.annotation.constraint.ThreadSafeAfterInit;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;

import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.xmlsec.context.SecurityParametersContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod;
import com.nimbusds.oauth2.sdk.auth.JWTAuthentication;
import com.nimbusds.oauth2.sdk.id.ClientID;

/**
 * A validator that handles authentication via signed JWT.
 * 
 * <p>For now, implemented via Nimbus APIs.</p>
 * 
 * TODO: there will be additional validation checks added once implemented on the older branch
 */
@ThreadSafeAfterInit
public class JWTCredentialValidator extends AbstractCredentialValidator {
    
    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(JWTCredentialValidator.class);
    
    /** Strategy that will return {@link OAuth2ClientAuthenticationContext}. */
    @Nonnull private Function<ProfileRequestContext,OAuth2ClientAuthenticationContext> clientAuthContextLookupStrategy;

    /** Strategy used to locate the {@link SecurityParametersContext} to use for verification. */
    @Nonnull private Function<ProfileRequestContext,SecurityParametersContext> securityParametersLookupStrategy;

    /** Strategy used to obtain {@link ClaimsValidator}. */
    @Nonnull private Function<ProfileRequestContext,ClaimsValidator> claimsValidatorLookupStrategy;
    
    /** Whether to save the JWT in the Java Subject's public credentials. */
    private boolean saveTokenToCredentialSet;
    
    /** Constructor. */
    public JWTCredentialValidator() {
        // PRC -> AuthenticationContext -> OAuth2ClientAuthenticationContext
        clientAuthContextLookupStrategy = new ChildContextLookup<>(OAuth2ClientAuthenticationContext.class).compose(
                new ChildContextLookup<>(AuthenticationContext.class));
        // PRC -> RP -> SPC
        securityParametersLookupStrategy = new ChildContextLookup<>(SecurityParametersContext.class).compose(
                new ChildContextLookup<>(RelyingPartyContext.class));
        
        claimsValidatorLookupStrategy = new ClaimsValidatorLookupFunction();
    }
    
    /**
     * Set the strategy used to return the {@link OAuth2ClientAuthenticationContext}.
     * 
     * @param strategy lookup strategy
     */
    public void setOAuth2ClientAuthenticationLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,OAuth2ClientAuthenticationContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
    
        clientAuthContextLookupStrategy =
                Constraint.isNotNull(strategy, "OAuth2ClientAuthenticationContext lookup strategy cannot be null");
    }
    
    /**
     * Set the strategy used to locate the {@link SecurityParametersContext} to use.
     * 
     * @param strategy lookup strategy
     */
    public void setSecurityParametersLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,SecurityParametersContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        securityParametersLookupStrategy =
                Constraint.isNotNull(strategy, "SecurityParameterContext lookup strategy cannot be null");
    }

    /**
     * Set the strategy used to locate {@link ClaimsValidator} used.
     * 
     * @param strategy lookup strategy
     */
    public void setClaimsValidatorLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,ClaimsValidator> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        claimsValidatorLookupStrategy =
                Constraint.isNotNull(strategy, "ClaimsValidator lookup strategy cannot be null");
    }
    
    /**
     * Set whether to save the JWT in the Java Subject's public credentials.
     * 
     * <p>Defaults to true</p>
     * 
     * @param flag flag to set
     */
    public void setSaveTokenToCredentialSet(final boolean flag) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        saveTokenToCredentialSet = flag;
    }
    
// Checkstyle: CyclomaticComplexity OFF
    /** {@inheritDoc} */
    @Override
    @Nullable protected Subject doValidate(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final AuthenticationContext authenticationContext,
            @Nullable final WarningHandler warningHandler,
            @Nullable final ErrorHandler errorHandler) throws Exception {
        
        final OAuth2ClientAuthenticationContext clientAuthContext =
                clientAuthContextLookupStrategy.apply(profileRequestContext);
        if (clientAuthContext == null || clientAuthContext.getClientAuthentication() == null) {
            log.debug("{} No OAuth 2.0 client authentication information found", getLogPrefix());
            return null;
        }
        
        final ClientAuthentication clientAuth = clientAuthContext.getClientAuthentication();
        if (!ClientAuthenticationMethod.CLIENT_SECRET_JWT.equals(clientAuth.getMethod()) &&
                !ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(clientAuth.getMethod())) {
            log.debug("{} OAuth client authentication for '{}' of unsupported type: {}", getLogPrefix(),
                    clientAuth.getClientID(), clientAuth.getMethod());
            return null;
        }
        
        if (!(clientAuth instanceof JWTAuthentication)) {
            log.warn("{} OAuth client authentication object of unexpected type: {}", getLogPrefix(),
                    clientAuth.getClass().getSimpleName());
            log.info("{} Login by '{}' failed", getLogPrefix(), clientAuth.getClientID());
            final LoginException e = new LoginException(AuthnEventIds.INVALID_CREDENTIALS); 
            if (errorHandler != null) { 
                errorHandler.handleError(profileRequestContext, authenticationContext, e,
                        AuthnEventIds.INVALID_CREDENTIALS);
            }
            throw e;
        }
        
        final JWTAuthentication jwtAuth = (JWTAuthentication) clientAuth;
        final String errorEventId = JWTSignatureValidationUtil.validateSignature(
                securityParametersLookupStrategy.apply(profileRequestContext), jwtAuth.getClientAssertion(),
                AuthnEventIds.INVALID_CREDENTIALS);
        if (errorEventId != null) {
            log.info("{} Login by '{}' failed", getLogPrefix(), clientAuth.getClientID());
            final LoginException e = new LoginException(errorEventId); 
            if (errorHandler != null) { 
                errorHandler.handleError(profileRequestContext, authenticationContext, e,
                        AuthnEventIds.INVALID_CREDENTIALS);
            }
            throw e;
        }

        try {
            validateJWTClaims(profileRequestContext, jwtAuth.getClientAssertion(), clientAuth.getClientID());
        } catch (final Exception e) {
            log.info("{} Login by '{}' failed", getLogPrefix(), clientAuth.getClientID());
            if (errorHandler != null) { 
                errorHandler.handleError(profileRequestContext, authenticationContext, e,
                        AuthnEventIds.INVALID_CREDENTIALS);
            }
            throw e;
        }
        
        log.info("{} Login by '{}' succeeded", getLogPrefix(), clientAuth.getClientID());
        
        return populateSubject(clientAuth.getClientID(), jwtAuth.getClientAssertion());
    }    
// Checkstyle: CyclomaticComplexity ON

    /**
     * Validates the contents of the given JWT against the requirements set in the OIDC core specification section 9.
     * 
     * @param jwt JWT to be validated
     * @param clientId client ID from which the JWT is coming from
     * @param profileRequestContext profile request context
     * 
     * @throws ParseException if unable to parse the claim set
     * @throws JWTValidationException if the claims fail to validate
     */
    protected void validateJWTClaims(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final SignedJWT jwt, @Nonnull final ClientID clientId)
                    throws ParseException, JWTValidationException {

        final ClaimsValidator validator = claimsValidatorLookupStrategy.apply(profileRequestContext);
        if (validator == null) {
            log.warn("{} JWT validation failed for client '{}': No ClaimsValidator found in configuration",
                    getLogPrefix(), clientId);
            throw new JWTValidationException("No ClaimsValidator found in configuration");
        }
        
        final JWTClaimsSet claimsSet;
        try {
            claimsSet = jwt.getJWTClaimsSet();
        } catch (final ParseException e) {
            log.warn("{} Could not parse the JWT from client '{}' into claims set", getLogPrefix(), clientId);
            throw e;
        }
        
        try {
            validator.validate(claimsSet, profileRequestContext);
        } catch (final JWTValidationException e) {
            log.warn("{} JWT validation failed for client '{}': {}", getLogPrefix(), clientId, e.getMessage());
            throw e;
        }
    } 
    
   /**
    * Builds a subject with "standard" content from the validation.
    *
    * @param clientId client ID
    * @param token the token validated
    * 
    * @return the decorated subject
    */
   @Nonnull protected Subject populateSubject(@Nonnull @NotEmpty final ClientID clientId,
           @Nonnull final SignedJWT token) {
      
       final Subject subject = new Subject();
       subject.getPrincipals().add(new UsernamePrincipal(clientId.getValue()));
       if (saveTokenToCredentialSet) {
           subject.getPublicCredentials().add(token);
       }
      
       return super.populateSubject(subject);
   }

}