/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.admin.impl;

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.time.format.DateTimeParseException;
import java.util.Map;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletResponse;

import org.opensaml.messaging.context.MessageContext;
import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.nimbusds.oauth2.sdk.AccessTokenResponse;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.token.Tokens;

import net.shibboleth.idp.authn.context.AuthenticationContext;
import net.shibboleth.idp.authn.context.SubjectContext;
import net.shibboleth.idp.plugin.oidc.op.cli.IssueRegistrationAccessTokenArguments;
import net.shibboleth.idp.plugin.oidc.op.token.support.RegistrationClaimsSet;
import net.shibboleth.idp.profile.context.navigate.ResponderIdLookupFunction;
import net.shibboleth.idp.profile.function.SpringFlowScopeLookupFunction;
import net.shibboleth.oidc.metadata.policy.MetadataPolicy;
import net.shibboleth.utilities.java.support.annotation.constraint.NonnullAfterInit;
import net.shibboleth.utilities.java.support.annotation.constraint.NotEmpty;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.logic.FunctionSupport;
import net.shibboleth.utilities.java.support.primitive.StringSupport;
import net.shibboleth.utilities.java.support.security.AccessControlService;
import net.shibboleth.utilities.java.support.security.DataSealer;
import net.shibboleth.utilities.java.support.security.DataSealerException;
import net.shibboleth.utilities.java.support.security.IdentifierGenerationStrategy;
import net.shibboleth.utilities.java.support.security.impl.SecureRandomIdentifierGenerationStrategy;

/**
 * Action that issues access token to be used for the OIDC dynamic registration endpoint.
 * 
 * <p>On success, {@link AccessTokenResponse} is built and attached as a message for the outbound message context. Also
 * a proceed event is built. On error, a non-proceed event is built.</p>
 * 
 * <p>Several access control checks are made to named policies in the case that certain options are supplied. These
 * checks are performed before the metadata policy is resolved.</p>
 * 
 * <p>The requested token lifetime is parsed with {@link Duration#parse(CharSequence)}. A missing, blank, unparseable
 * or non-positive value falls back to the default lifetime, and a value greater than the default is lowered to the
 * default.</p>
 * 
 * @event {@link EventIds#PROCEED_EVENT_ID}
 * @event {@link EventIds#INVALID_PROFILE_CTX}
 * @event {@link EventIds#IO_ERROR}
 * 
 * @since 3.1.0
 */
public class IssueRegistrationAccessToken extends AbstractAdminApiProfileAction {
    
    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(IssueRegistrationAccessToken.class);

    /** Data sealer for handling access token. */
    @NonnullAfterInit private DataSealer dataSealer;

    /** Strategy used to locate the {@link IdentifierGenerationStrategy} to use. */
    @Nonnull private Function<ProfileRequestContext,IdentifierGenerationStrategy> idGeneratorLookupStrategy;

    /** Access control service. */
    @NonnullAfterInit private AccessControlService accessControlService;

    /** Name of access control policy governing policyLocation acceptance. */
    @Nullable @NotEmpty private String policyLocationPolicyName;

    /** Name of access control policy governing policyId acceptance. */
    @Nullable @NotEmpty private String policyIdPolicyName;

    /** Name of access control policy governing clientId acceptance. */
    @Nullable @NotEmpty private String clientIdPolicyName;

    /** Lookup function for the metadata policy. */
    @NonnullAfterInit private Function<ProfileRequestContext,Map<String,MetadataPolicy>> metadataPolicyLookupStrategy;

    /** Lookup function for the token lifetime. */
    @Nonnull private Function<ProfileRequestContext,String> tokenLifetimeLookupStrategy;
    
    /** Lookup function for the token issuer. */
    @Nonnull private Function<ProfileRequestContext,String> issuerLookupStrategy;

    /** Lookup function for the policy location. */
    @Nonnull private Function<ProfileRequestContext,String> policyLocationLookupStrategy;

    /** Lookup function for the policy identifier. */
    @Nonnull private Function<ProfileRequestContext,String> policyIdLookupStrategy;

    /** Lookup function for the client identifier. */
    @Nonnull private Function<ProfileRequestContext,String> clientIdLookupStrategy;

    /** Lookup function for the flag signaling replacement use of the token. */
    @Nonnull private Function<ProfileRequestContext,String> replacementLookupStrategy;

    /** The identifier generator to use. */
    @Nullable private IdentifierGenerationStrategy idGenerator;

    /** The resolved metadata policy. */
    @Nullable private Map<String,MetadataPolicy> metadataPolicy;
    
    /** The token issuer. */
    @Nonnull private String issuer;
    
    /** The policy location. */
    @Nullable private String policyLocation;

    /** The policy identifier. */
    @Nullable private String policyId;

    /** The client identifier. */
    @Nullable private String clientId;

    /** The default token lifetime, which is also the maximum lifetime a caller may request. */
    @Nonnull private Duration defaultTokenLifetime;

    /** The token lifetime resolved for the current request. */
    @Nullable private Duration tokenLifetime;
    
    /**
     * Constructor.
     */
    public IssueRegistrationAccessToken() {
        idGeneratorLookupStrategy = FunctionSupport.constant(new SecureRandomIdentifierGenerationStrategy());
        issuerLookupStrategy = new ResponderIdLookupFunction();
        tokenLifetimeLookupStrategy =
                new SpringFlowScopeLookupFunction(IssueRegistrationAccessTokenArguments.URL_PARAM_LIFETIME);
        policyLocationLookupStrategy =
                new SpringFlowScopeLookupFunction(IssueRegistrationAccessTokenArguments.URL_PARAM_POLICY_LOCATION);
        policyIdLookupStrategy =
                new SpringFlowScopeLookupFunction(IssueRegistrationAccessTokenArguments.URL_PARAM_POLICY_ID);
        clientIdLookupStrategy =
                new SpringFlowScopeLookupFunction(IssueRegistrationAccessTokenArguments.URL_PARAM_CLIENT_ID);
        replacementLookupStrategy =
                new SpringFlowScopeLookupFunction(IssueRegistrationAccessTokenArguments.URL_PARAM_REPLACEMENT);
        
        defaultTokenLifetime = Duration.ofDays(1);
    }
    
    /**
     * Set the data sealer for handling access token.
     * 
     * @param sealer data sealer.
     */
    public void setSealer(@Nonnull final DataSealer sealer) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        dataSealer = Constraint.isNotNull(sealer, "Data sealer cannot be null");
    }
    
    /**
     * Set the {@link AccessControlService} to use.
     * 
     * @param acs service to use
     */
    public void setAccessControlService(@Nonnull final AccessControlService acs) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        accessControlService = Constraint.isNotNull(acs, "AccessControlService cannot be null");
    }

    /**
     * Set the strategy used to locate the {@link IdentifierGenerationStrategy} to use.
     * 
     * @param strategy lookup strategy
     */
    public void setIdentifierGeneratorLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,IdentifierGenerationStrategy> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        idGeneratorLookupStrategy =
                Constraint.isNotNull(strategy, "IdentifierGenerationStrategy lookup strategy cannot be null");
    }

    /**
     * Set a lookup strategy for the token issuer.
     * 
     * @param strategy lookup strategy
     */
    public void setIssuerLookupStrategy(@Nonnull final Function<ProfileRequestContext,String> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        issuerLookupStrategy = Constraint.isNotNull(strategy, "Issuer lookup strategy cannot be null");
    }

    /**
     * Set a lookup strategy for the metadata policy.
     * 
     * @param strategy lookup strategy
     */
    public void setMetadataPolicyLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,Map<String,MetadataPolicy>> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        metadataPolicyLookupStrategy =
                Constraint.isNotNull(strategy, "Metadata policy lookup strategy cannot be null");
    }
    
    /**
     * Set a lookup strategy for the token lifetime.
     * 
     * @param strategy lookup strategy
     */
    public void setTokenLifetimeLookupStrategy(@Nonnull final Function<ProfileRequestContext,String> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        tokenLifetimeLookupStrategy = Constraint.isNotNull(strategy, "Token lifetime lookup strategy cannot be null");
    }

    /**
     * Set a lookup strategy for the metadata policy location.
     * 
     * @param strategy lookup strategy
     */
    public void setPolicyLocationLookupStrategy(@Nonnull final Function<ProfileRequestContext,String> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        policyLocationLookupStrategy = Constraint.isNotNull(strategy, "Policy location lookup strategy cannot be null");
    }

    /**
     * Set a lookup strategy for the policy identifier, which is recorded in the token as the relying party identifier.
     * 
     * @param strategy lookup strategy
     */
    public void setPolicyIdLookupStrategy(@Nonnull final Function<ProfileRequestContext,String> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        policyIdLookupStrategy = Constraint.isNotNull(strategy, "Policy ID lookup strategy cannot be null");
    }

    /**
     * Set a lookup strategy for the client identifier.
     * 
     * @param strategy lookup strategy
     */
    public void setClientIdLookupStrategy(@Nonnull final Function<ProfileRequestContext,String> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        clientIdLookupStrategy = Constraint.isNotNull(strategy, "Client ID lookup strategy cannot be null");
    }

    /**
     * Set a lookup strategy for the flag signaling registration replacement is allowed.
     * 
     * @param strategy lookup strategy
     */
    public void setReplacementLookupStrategy(@Nonnull final Function<ProfileRequestContext,String> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);

        replacementLookupStrategy = Constraint.isNotNull(strategy, "Replacement lookup strategy cannot be null");
    }

    /**
     * Set an explicit policy name to apply governing policyLocation usage.
     * 
     * @param name  policy name
     */
    public void setPolicyLocationPolicyName(@Nullable @NotEmpty final String name) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        policyLocationPolicyName = StringSupport.trimOrNull(name);
    }

    /**
     * Set an explicit policy name to apply governing policyId usage.
     * 
     * @param name  policy name
     */
    public void setPolicyIdPolicyName(@Nullable @NotEmpty final String name) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        policyIdPolicyName = StringSupport.trimOrNull(name);
    }

    /**
     * Set an explicit policy name to apply governing clientId usage.
     * 
     * @param name  policy name
     */
    public void setClientIdPolicyName(@Nullable @NotEmpty final String name) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        clientIdPolicyName = StringSupport.trimOrNull(name);
    }
    
    /**
     * Set the default token lifetime. This is also the upper bound on any lifetime requested by the caller.
     * 
     * @param lifetime  token lifetime
     */
    public void setDefaultTokenLifetime(@Nonnull final Duration lifetime) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        defaultTokenLifetime = Constraint.isNotNull(lifetime, "Default token lifetime cannot be null");
    }

    /** {@inheritDoc} */
    @Override
    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        
        if (dataSealer == null) {
            throw new ComponentInitializationException("DataSealer cannot be null");
        }
        
        if (accessControlService == null) {
            throw new ComponentInitializationException("AccessControlService cannot be null");
        }
        
        if (metadataPolicyLookupStrategy == null) {
            throw new ComponentInitializationException("Metadata policy lookup strategy cannot be null");
        }
    }

    // Checkstyle: CyclomaticComplexity OFF

    /**
     * {@inheritDoc}
     * 
     * <p>Resolves all request inputs and applies the access control checks. The checks run before the metadata policy
     * lookup strategy is applied, so the lookup is never reached for a caller that is not authorized to use the
     * supplied options.</p>
     */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        
        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }

        idGenerator = idGeneratorLookupStrategy.apply(profileRequestContext);

        try {
            if (idGenerator == null) {
                log.error("{} No identifier generation strategy", getLogPrefix());
                sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
                        "Internal Server Error", "System misconfiguration.");
                return false;
            }

            issuer = issuerLookupStrategy.apply(profileRequestContext);
            if (issuer == null) {
                log.error("{} No issuer could be resolved", getLogPrefix());
                sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
                        "Internal Server Error", "System misconfiguration.");
                return false;
            }

            policyLocation = policyLocationLookupStrategy.apply(profileRequestContext);
            policyId = policyIdLookupStrategy.apply(profileRequestContext);
            clientId = clientIdLookupStrategy.apply(profileRequestContext);

            // Authorize use of the supplied options before resolving the metadata policy. The lookup strategy may act
            // on the supplied policy location, and that must not happen for an unauthorized caller.
            // checkAccess handles (and signals) its own I/O errors.
            if (!checkAccess(profileRequestContext)) {
                return false;
            }

            metadataPolicy = metadataPolicyLookupStrategy.apply(profileRequestContext);
            if (policyLocation != null && metadataPolicy == null) {
                log.warn("{} No metadata policy could be resolved from the given location: {}", getLogPrefix(),
                        policyLocation);
                sendError(HttpServletResponse.SC_BAD_REQUEST,
                        "Invalid Request", "No metadata policy or policy ID could be resolved.");
                return false;
            }
            if (metadataPolicy == null && policyId == null) {
                log.warn("{} No metadata policy or policy ID could be resolved", getLogPrefix());
                sendError(HttpServletResponse.SC_BAD_REQUEST,
                        "Invalid Request", "No metadata policy or policy ID could be resolved.");
                return false;
            }
        } catch (final IOException e) {
            log.error("{} I/O error issuing API response", getLogPrefix(), e);
            ActionSupport.buildEvent(profileRequestContext, EventIds.IO_ERROR);
            return false;
        }
        
        // Always assign, so a value left over from an earlier execution is never reused.
        tokenLifetime = resolveTokenLifetime(tokenLifetimeLookupStrategy.apply(profileRequestContext));
        
        return true;
    }

    // Checkstyle: CyclomaticComplexity ON

    /**
     * Resolve the effective token lifetime from the (optional) requested value.
     * 
     * <p>The requested value is trimmed and parsed with {@link Duration#parse(CharSequence)}. A missing or blank value,
     * an unparseable value, or a non-positive duration results in the default lifetime. A duration greater than the
     * default is lowered to the default.</p>
     * 
     * @param requestedLifetime the requested lifetime string, may be null
     * 
     * @return the lifetime to apply to the token, never null
     */
    @Nonnull private Duration resolveTokenLifetime(@Nullable final String requestedLifetime) {
        final String lifetimeString = StringSupport.trimOrNull(requestedLifetime);
        if (lifetimeString == null) {
            log.debug("{} No token lifetime specified, using default", getLogPrefix());
            return defaultTokenLifetime;
        }
        
        final Duration requested;
        try {
            requested = Duration.parse(lifetimeString);
        } catch (final DateTimeParseException e) {
            // Previously the lifetime was left unset here, causing a NullPointerException when computing expiration.
            log.warn("{} Token lifetime was not in a supported format, using default", getLogPrefix(), e);
            return defaultTokenLifetime;
        }
        
        // A zero or negative lifetime would produce a token that is already expired.
        if (requested.isNegative() || requested.isZero()) {
            log.warn("{} Requested token lifetime {} is not positive, using default", getLogPrefix(), requested);
            return defaultTokenLifetime;
        }
        
        if (requested.compareTo(defaultTokenLifetime) > 0) {
            log.warn("{} Requested token lifetime greater than default, lowering to default", getLogPrefix());
            return defaultTokenLifetime;
        }
        
        return requested;
    }

    /** {@inheritDoc} */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        
        // Access control was already enforced in doPreExecute, before the metadata policy was resolved.
        final String id = idGenerator.generateIdentifier();
        
        final Instant now = Instant.now();
        final Instant exp = now.plus(tokenLifetime);

        final RegistrationClaimsSet.Builder builder = new RegistrationClaimsSet.Builder(id)
                .withIssuer(issuer)
                .withIssuedAt(now)
                .withExpiration(exp)
                .withMetadata(metadataPolicy)
                .withRelyingPartyId(policyId);
        
        // The replacement flag is only meaningful (and only recorded) when the token is bound to a client ID.
        if (clientId != null) {
            builder.withClientId(clientId)
                .withReplacement(Boolean.valueOf(replacementLookupStrategy.apply(profileRequestContext)));
        }

        addAuthenticationClaims(profileRequestContext, builder);
        
        final RegistrationClaimsSet claimsSet = builder.build();
        
        final AccessToken accessToken;
        
        try {
            final String value = getObjectMapper().writeValueAsString(claimsSet);
            log.debug("{} Built the following JSON to be sealed {}", getLogPrefix(), value);
            final String encryptedValue = dataSealer.wrap(value, claimsSet.getExpiration());
            log.debug("{} Encrypted the JSON into {}", getLogPrefix(), encryptedValue);
            accessToken = new BearerAccessToken(encryptedValue, tokenLifetime.getSeconds(), null);
        } catch (final JsonProcessingException e) {
            log.error("{} Could not build JSON", getLogPrefix(), e);
            ActionSupport.buildEvent(profileRequestContext, EventIds.IO_ERROR);
            return;
        } catch (final DataSealerException e) {
            log.error("{} Could not encrypt the claims set", getLogPrefix(), e);
            ActionSupport.buildEvent(profileRequestContext, EventIds.IO_ERROR);
            return;
        }
        
        final TokenResponse response = new AccessTokenResponse(new Tokens(accessToken, null));
        final MessageContext mc = new MessageContext();
        mc.setMessage(response);
        profileRequestContext.setOutboundMessageContext(mc);
    }

    /**
     * Check access policies for each supplied option (policyId, policyLocation, clientId).
     * 
     * <p>On denial or I/O failure an error response or event has already been produced when this returns false.</p>
     * 
     * @param profileRequestContext current profile request context
     * 
     * @return true iff checks pass
     */
    private boolean checkAccess(@Nonnull final ProfileRequestContext profileRequestContext) {
        try {
            return checkOptionAccess("policyId", policyIdPolicyName, "read", policyId)
                    && checkOptionAccess("policyLocation", policyLocationPolicyName, "read", policyLocation)
                    && checkOptionAccess("clientId", clientIdPolicyName, "write", clientId);
        } catch (final IOException e) {
            log.error("{} I/O error issuing API response", getLogPrefix(), e);
            ActionSupport.buildEvent(profileRequestContext, EventIds.IO_ERROR);
            return false;
        }
    }

    /**
     * Check a single supplied option against its named access control policy.
     * 
     * <p>An option that was not supplied passes trivially. A supplied option with no governing policy name configured
     * is denied. Otherwise the named policy decides.</p>
     * 
     * @param optionName name of the option, used in log and error messages
     * @param policyName name of the access control policy governing the option, may be null
     * @param operation operation passed to the access control check
     * @param resource the supplied option value, may be null if not supplied
     * 
     * @return true iff the option was not supplied or its use is allowed
     * 
     * @throws IOException if sending the error response fails
     */
    private boolean checkOptionAccess(@Nonnull @NotEmpty final String optionName, @Nullable final String policyName,
            @Nonnull @NotEmpty final String operation, @Nullable final String resource) throws IOException {
        if (resource == null) {
            return true;
        }
        
        if (policyName == null) {
            log.warn("{} No policy name governing {} usage, disallowing access", getLogPrefix(), optionName);
            sendError(HttpServletResponse.SC_FORBIDDEN,
                    "Access Denied", "No policy name governing " + optionName + " usage, disallowing access.");
            return false;
        }
        
        if (!accessControlService.getInstance(policyName).checkAccess(getHttpServletRequest(), operation, resource)) {
            log.warn("{} Use of {} '{}' denied by access control policy '{}'", getLogPrefix(), optionName, resource,
                    policyName);
            sendError(HttpServletResponse.SC_FORBIDDEN,
                    "Access Denied", "Operation is not allowed with the current policy.");
            return false;
        }
        
        return true;
    }

    /**
     * Decorate the token with authentication-related claims.
     * 
     * <p>Adds the authentication instant if an {@link AuthenticationContext} with a result is present, and the
     * principal name if a {@link SubjectContext} with a principal name is present.</p>
     * 
     * @param profileRequestContext profile request context
     * @param builder claims set builder
     */
    private void addAuthenticationClaims(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final RegistrationClaimsSet.Builder builder) {
        
        final AuthenticationContext authnContext = profileRequestContext.getSubcontext(AuthenticationContext.class);
        if (authnContext != null && authnContext.getAuthenticationResult() != null) {
            builder.withAuthTime(authnContext.getAuthenticationResult().getAuthenticationInstant());
        }
        
        final SubjectContext subjectContext = profileRequestContext.getSubcontext(SubjectContext.class);
        if (subjectContext != null && subjectContext.getPrincipalName() != null) {
            builder.withPrincipal(subjectContext.getPrincipalName());
        }
    }

}