/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.oidc.op.profile.impl;

import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.action.EventIds;
import org.opensaml.profile.context.ProfileRequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.openid.connect.sdk.rp.OIDCClientMetadata;
import com.nimbusds.openid.connect.sdk.rp.OIDCClientRegistrationRequest;

import net.minidev.json.JSONObject;
import net.shibboleth.idp.plugin.oidc.op.messaging.context.OIDCClientRegistrationMetadataPolicyContext;
import net.shibboleth.idp.profile.AbstractProfileAction;
import net.shibboleth.oidc.metadata.policy.MetadataPolicy;
import net.shibboleth.oidc.metadata.policy.impl.DefaultMetadataPolicyEnforcer;
import net.shibboleth.utilities.java.support.collection.Pair;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;

/**
 * Validates the incoming dynamic client registration request against the metadata policy stored in the
 * {@link OIDCClientRegistrationMetadataPolicyContext}. The policy-enforced request metadata is stored via
 * {@link OIDCClientRegistrationMetadataPolicyContext#setPolicyEnforcedMetadata(OIDCClientMetadata)}.
 */
/**
 * Validates the incoming dynamic client registration request against the metadata policy stored in the
 * {@link OIDCClientRegistrationMetadataPolicyContext}. The policy-enforced request metadata is stored via
 * {@link OIDCClientRegistrationMetadataPolicyContext#setPolicyEnforcedMetadata(OIDCClientMetadata)}.
 *
 * <p>If the context carries no (or an empty) metadata policy, the request metadata is stored unchanged.
 * If any claim is not compliant with its policy, an {@link EventIds#INVALID_MESSAGE} event is signaled and
 * no enforced metadata is stored. If the enforced metadata cannot be parsed back into
 * {@link OIDCClientMetadata}, an {@link EventIds#INVALID_MSG_CTX} event is signaled.</p>
 */
public class ValidateRegistrationRequestMetadata extends AbstractProfileAction {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(ValidateRegistrationRequestMetadata.class);
    
    /** Strategy that will return {@link OIDCClientRegistrationMetadataPolicyContext}. */
    @Nonnull private Function<MessageContext, OIDCClientRegistrationMetadataPolicyContext>
        registrationMetadataPolicyContextLookupStrategy;
    
    /**
     * Function used for enforcing the metadata policy. It is given the request value of a claim (null when
     * absent) and the claim's policy, and returns the enforced value paired with a compliance flag.
     */
    @Nonnull private BiFunction<Object,MetadataPolicy,Pair<Object,Boolean>> metadataPolicyEnforcer;

    /** The OIDCClientRegistrationRequest to validate. */
    @Nullable private OIDCClientRegistrationRequest request;
    
    /** The metadata policy context to operate on. */
    @Nullable private OIDCClientRegistrationMetadataPolicyContext registrationMetadataPolicyContext;

    /** The metadata policy used for validation, keyed by metadata claim name. */
    @Nullable private Map<String, MetadataPolicy> metadataPolicy;
    
    /**
     * Constructor.
     */
    public ValidateRegistrationRequestMetadata() {
        registrationMetadataPolicyContextLookupStrategy = 
                new ChildContextLookup<>(OIDCClientRegistrationMetadataPolicyContext.class);

        metadataPolicyEnforcer = new DefaultMetadataPolicyEnforcer();
    }
    
    /**
     * Set the strategy that will return {@link OIDCClientRegistrationMetadataPolicyContext}.
     * 
     * @param strategy Strategy that will return {@link OIDCClientRegistrationMetadataPolicyContext}.
     */
    public void setRegistrationMetadataPolicyContextLookupStrategy(
            @Nonnull final Function<MessageContext, OIDCClientRegistrationMetadataPolicyContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        registrationMetadataPolicyContextLookupStrategy = Constraint.isNotNull(strategy,
                "Registration metadata policy context lookup strategy cannot be null");
    }
 
    /**
     * Set the function used for enforcing the metadata policy.
     * 
     * @param function Function used for enforcing the metadata policy.
     */
    public void setMetadataPolicyEnforcer(
            @Nonnull final BiFunction<Object,MetadataPolicy,Pair<Object,Boolean>> function) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        metadataPolicyEnforcer = Constraint.isNotNull(function, "The metadata policy enforcer cannot be null");
    }

    /** {@inheritDoc} */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        if (!super.doPreExecute(profileRequestContext)) {
            return false;
        }
        
        final MessageContext messageContext = profileRequestContext.getInboundMessageContext();

        if (messageContext == null) {
            log.debug("{} No inbound message context associated with this profile request", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, EventIds.INVALID_PROFILE_CTX);
            return false;
        }
        
        // instanceof is false for null, so this also covers a missing message.
        final Object message = messageContext.getMessage();
        if (!(message instanceof OIDCClientRegistrationRequest)) {
            log.debug("{} No inbound message associated with this profile request", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, EventIds.INVALID_MSG_CTX);
            return false;                        
        }
        request = (OIDCClientRegistrationRequest) message;

        registrationMetadataPolicyContext = registrationMetadataPolicyContextLookupStrategy.apply(messageContext);
        if (registrationMetadataPolicyContext == null) {
            log.debug("{} No metadata policy context associated with this request", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, EventIds.INVALID_MSG_CTX);
            return false;                        
        }
        metadataPolicy = registrationMetadataPolicyContext.getMetadataPolicy();
        
        return true;
    }
    
    /** {@inheritDoc} */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext) {
        if (metadataPolicy == null || metadataPolicy.isEmpty()) {
            log.debug("{} No metadata policy found, setting the request as policy enforced", getLogPrefix());
            registrationMetadataPolicyContext.setPolicyEnforcedMetadata(request.getOIDCClientMetadata());
            return;
        }
        log.debug("{} Metadata policy used for request validation: {}", getLogPrefix(), metadataPolicy);

        final JSONObject requestMetadata = request.getOIDCClientMetadata().toJSONObject();

        if (!enforcePolicy(requestMetadata)) {
            log.warn("{} The requested metadata is not compliant with the policy", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, EventIds.INVALID_MESSAGE);
            // Do not store partially enforced metadata as if it were compliant, and do not let a later
            // parse failure overwrite the INVALID_MESSAGE event.
            return;
        }
        
        try {
            final OIDCClientMetadata enforcedMetadata = OIDCClientMetadata.parse(requestMetadata);
            registrationMetadataPolicyContext.setPolicyEnforcedMetadata(enforcedMetadata);
            log.debug("{} The enforced metadata stored in context", getLogPrefix());
        } catch (final ParseException e) {
            log.error("{} Could not parse the enforced metadata", getLogPrefix(), e);
            ActionSupport.buildEvent(profileRequestContext, EventIds.INVALID_MSG_CTX);
        }
    }

    /**
     * Apply the policy of every claim in {@link #metadataPolicy} to the given request metadata.
     *
     * <p>Every claim is evaluated, even after a non-compliant one, so that all violations are logged. For a
     * compliant claim, the enforced value replaces the request value in place. A null enforced value removes
     * the claim instead of writing an explicit null member. Non-compliant claims are left untouched.</p>
     *
     * @param requestMetadata the request metadata as JSON, modified in place
     * 
     * @return true if every claim in the policy is compliant, false otherwise
     */
    private boolean enforcePolicy(@Nonnull final JSONObject requestMetadata) {
        boolean compliant = true;

        for (final Map.Entry<String, MetadataPolicy> entry : metadataPolicy.entrySet()) {
            final String claim = entry.getKey();
            final Object value = requestMetadata.get(claim);
            log.debug("{} Claim {} set in policy included in the request: {}", getLogPrefix(), claim,
                    value != null);
            final Pair<Object,Boolean> result = metadataPolicyEnforcer.apply(value, entry.getValue());
            // A pluggable enforcer might return a null pair or a null flag; treat that as non-compliant
            // rather than failing with a NullPointerException.
            if (result == null || !Boolean.TRUE.equals(result.getSecond())) {
                log.warn("{} Metadata claim {} is not compliant with the policy", getLogPrefix(), claim);
                compliant = false;
            } else {
                log.trace("{} Validation result is OK for claim {}", getLogPrefix(), claim);
                final Object enforcedValue = result.getFirst();
                if (enforcedValue == null) {
                    requestMetadata.remove(claim);
                } else {
                    requestMetadata.put(claim, enforcedValue);
                }
            }
        }

        return compliant;
    }

}