package com.atlassian.crowd.directory.rest;

import com.atlassian.crowd.directory.query.GraphQuery;
import com.atlassian.crowd.directory.query.MicrosoftGraphQueryParam;
import com.atlassian.crowd.directory.query.ODataSelect;
import com.atlassian.crowd.directory.query.ODataTop;
import com.atlassian.crowd.directory.rest.endpoint.AzureApiUriResolver;
import com.atlassian.crowd.directory.rest.entity.GraphDirectoryObjectList;
import com.atlassian.crowd.directory.rest.entity.PageableGraphList;
import com.atlassian.crowd.directory.rest.entity.delta.GraphDeltaQueryGroupList;
import com.atlassian.crowd.directory.rest.entity.delta.GraphDeltaQueryUserList;
import com.atlassian.crowd.directory.rest.entity.group.GraphGroupList;
import com.atlassian.crowd.directory.rest.entity.user.GraphUsersList;
import com.atlassian.crowd.directory.rest.util.IoUtilsWrapper;
import com.atlassian.crowd.directory.rest.util.JerseyLoggingFilter;
import com.atlassian.crowd.exception.OperationFailedException;
import com.atlassian.security.xml.SecureXmlParserFactory;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Charsets;
import com.sun.jersey.api.client.Client;
import com.sun.jersey.api.client.ClientResponse;
import com.sun.jersey.api.client.UniformInterfaceException;
import com.sun.jersey.api.client.WebResource;
import com.sun.jersey.api.client.config.ClientConfig;
import com.sun.jersey.client.impl.ClientRequestImpl;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.slf4j.Logger;
import org.w3c.dom.Document;
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.UriBuilder;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpression;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.atlassian.crowd.directory.query.MicrosoftGraphQueryParams.asQueryParams;
import static org.slf4j.LoggerFactory.getLogger;

/**
 * The client used to communicate with Azure AD via Microsoft Graph. Expects and returns Microsoft Graph REST entities.
 * <p>
 * Requests built from scratch are issued against {@link #getGraphBaseResource()} (the configured Graph API URL followed
 * by {@value #GRAPH_API_VERSION}); paging requests follow the supplied next-page link. Every request has a
 * {@link JerseyLoggingFilter} attached.
 */
public class AzureAdRestClient {
    private static final Logger log = getLogger(AzureAdRestClient.class);
    public static final String GRAPH_API_VERSION = "/v1.0";

    public static final String GRAPH_USERS_ENDPOINT_SUFFIX = "users";
    public static final String GRAPH_GROUPS_ENDPOINT_SUFFIX = "groups";
    public static final String METADATA_ENDPOINT_SUFFIX = "$metadata";
    public static final String MEMBER_OF_NAVIGATIONAL_PROPERTY = "memberOf";
    public static final String MEMBERS_NAVIGATIONAL_PROPERTY = "members";
    public static final String DELTA_QUERY_ENDPOINT_SUFFIX = "delta";
    public static final String TRASH_ENDPOINT_SUFFIX = "directory/deleteditems";

    private static final String SCHEMA_XPATH = "/Edmx/DataServices/Schema";
    private static final String DELTA_RETURN_PATHS_XPATH = SCHEMA_XPATH + "/Function[@Name='delta']/ReturnType";
    private static final String CHARSET_PARAMETER_NAME = "charset";
    private static final String ALIAS_ATTRIBUTE_NAME = "Alias";
    private static final String RETURN_TYPE_ATTRIBUTE_NAME = "Type";
    private static final String NAMESPACE_ATTRIBUTE_NAME = "Namespace";
    public static final String COLLECTION_TYPE_FORMAT = "Collection(%s.%s)";
    public static final String USER_SUFFIX = "user";
    public static final String GROUP_SUFFIX = "group";

    /** Substituted into error messages when the body of an error response cannot be read. */
    private static final String UNREADABLE_ENTITY_PLACEHOLDER = "<unreadable>";

    private final Client client;
    private final String graphBaseEndpoint;
    private final IoUtilsWrapper ioUtilsWrapper;

    @VisibleForTesting
    public Client getClient() {
        return client;
    }

    /**
     * @param client               the Jersey client used to issue every request
     * @param endpointDataProvider supplies the Microsoft Graph API URL; it is read once, in this constructor
     * @param ioUtilsWrapper       converts the downloaded metadata document into a stream for XML parsing
     */
    public AzureAdRestClient(final Client client, AzureApiUriResolver endpointDataProvider, IoUtilsWrapper ioUtilsWrapper) {
        this.client = client;
        this.graphBaseEndpoint = endpointDataProvider.getGraphApiUrl();
        this.ioUtilsWrapper = ioUtilsWrapper;
    }

    /**
     * Searches users via {@code GET <base>/users}, sending the query's filter, select and limit as query parameters.
     *
     * @param query the search criteria
     * @return the deserialized JSON response
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public GraphUsersList searchUsers(final GraphQuery query) throws OperationFailedException {
        return handleRequest(() -> graphResource()
                .path(GRAPH_USERS_ENDPOINT_SUFFIX)
                .queryParams(asQueryParams(query.getFilter(), query.getSelect(), query.getLimit()))
                .accept(MediaType.APPLICATION_JSON_TYPE)
                .get(GraphUsersList.class));
    }

    /**
     * Searches groups via {@code GET <base>/groups}, sending the query's filter, select and limit as query parameters.
     *
     * @param query the search criteria
     * @return the deserialized JSON response
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public GraphGroupList searchGroups(final GraphQuery query) throws OperationFailedException {
        return handleRequest(() -> graphResource()
                .path(GRAPH_GROUPS_ENDPOINT_SUFFIX)
                .queryParams(asQueryParams(query.getFilter(), query.getSelect(), query.getLimit()))
                .accept(MediaType.APPLICATION_JSON_TYPE)
                .get(GraphGroupList.class));
    }

    /**
     * Fetches the objects a user is directly a member of via {@code GET <base>/users/{user}/memberOf},
     * requesting a full page ({@link ODataTop#FULL_PAGE}).
     *
     * @param nameOrExternalId the user identifier placed in the request path
     * @param select           the properties to select
     * @return the first page of direct parents
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public GraphDirectoryObjectList getDirectParentsOfUser(String nameOrExternalId, ODataSelect select) throws OperationFailedException {
        return handleRequest(() -> graphResource()
                .path(GRAPH_USERS_ENDPOINT_SUFFIX)
                .path(nameOrExternalId)
                .path(MEMBER_OF_NAVIGATIONAL_PROPERTY)
                .queryParams(asQueryParams(ODataTop.FULL_PAGE, select))
                .accept(MediaType.APPLICATION_JSON_TYPE)
                .get(GraphDirectoryObjectList.class));
    }

    /**
     * Fetches the objects a group is directly a member of via {@code GET <base>/groups/{groupId}/memberOf},
     * requesting a full page ({@link ODataTop#FULL_PAGE}).
     *
     * @param groupId the group identifier placed in the request path
     * @param select  the properties to select
     * @return the first page of direct parents
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public GraphDirectoryObjectList getDirectParentsOfGroup(String groupId, ODataSelect select) throws OperationFailedException {
        return handleRequest(() -> graphResource()
                .path(GRAPH_GROUPS_ENDPOINT_SUFFIX)
                .path(groupId)
                .path(MEMBER_OF_NAVIGATIONAL_PROPERTY)
                .queryParams(asQueryParams(ODataTop.FULL_PAGE, select))
                .accept(MediaType.APPLICATION_JSON_TYPE)
                .get(GraphDirectoryObjectList.class));
    }

    /**
     * Fetches the direct members of a group via {@code GET <base>/groups/{groupId}/members}.
     *
     * @param groupId the group identifier placed in the request path
     * @param select  the properties to select
     * @return the first page of direct members
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public GraphDirectoryObjectList getDirectChildrenOfGroup(String groupId, ODataSelect select) throws OperationFailedException {
        // NOTE(review): unlike the memberOf lookups, no $top is sent here, so the server's default page size applies;
        // confirm this is intentional.
        return handleRequest(() -> graphResource()
                .path(GRAPH_GROUPS_ENDPOINT_SUFFIX)
                .path(groupId)
                .path(MEMBERS_NAVIGATIONAL_PROPERTY)
                .queryParams(asQueryParams(select))
                .accept(MediaType.APPLICATION_JSON_TYPE)
                .get(GraphDirectoryObjectList.class));
    }

    /**
     * Runs a users delta query via {@code GET <base>/users/delta}.
     *
     * @param parameter the single query parameter to send
     * @return the deserialized delta response
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public GraphDeltaQueryUserList performUsersDeltaQuery(final MicrosoftGraphQueryParam parameter) throws OperationFailedException {
        return handleRequest(() -> graphResource()
                .path(GRAPH_USERS_ENDPOINT_SUFFIX)
                .path(DELTA_QUERY_ENDPOINT_SUFFIX)
                .queryParams(asQueryParams(parameter))
                .accept(MediaType.APPLICATION_JSON_TYPE)
                .get(GraphDeltaQueryUserList.class));
    }

    /**
     * Runs a groups delta query via {@code GET <base>/groups/delta}.
     *
     * @param parameters the query parameters to send
     * @return the deserialized delta response
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public GraphDeltaQueryGroupList performGroupsDeltaQuery(final MicrosoftGraphQueryParam... parameters) throws OperationFailedException {
        return handleRequest(() -> graphResource()
                .path(GRAPH_GROUPS_ENDPOINT_SUFFIX)
                .path(DELTA_QUERY_ENDPOINT_SUFFIX)
                .queryParams(asQueryParams(parameters))
                .accept(MediaType.APPLICATION_JSON_TYPE)
                .get(GraphDeltaQueryGroupList.class));
    }

    /**
     * Downloads the {@code $metadata} document and checks whether the declared {@code delta} functions return both a
     * user collection and a group collection.
     *
     * @return true if delta queries are supported for both users and groups, false otherwise
     * @throws UniformInterfaceException if the metadata endpoint responds with a status of 300 or above
     * @throws RuntimeException          wrapping any XPath, I/O or XML parsing failure
     */
    @SuppressFBWarnings(value = {"XXE_DOCUMENT", "XPATH_INJECTION"},
            justification = "uses atlassian-secure-xml; the XPath expressions are constants and no user input is processed")
    public boolean supportsDeltaQuery() {
        // Build the resource once and log exactly the URI that is requested
        final WebResource metadataResource = graphResource().path(METADATA_ENDPOINT_SUFFIX);
        log.debug("Fetching metadata from URI {}", metadataResource.getURI());
        try {
            final ClientResponse response = metadataResource.get(ClientResponse.class);
            checkStatusCode(response);

            final Charset encoding = extractEncoding(response);
            final String xmlResponseBody = response.getEntity(String.class);

            final Document metadataDocument = SecureXmlParserFactory.newDocumentBuilder()
                    .parse(ioUtilsWrapper.toInputStream(xmlResponseBody, encoding));
            final XPath xPath = XPathFactory.newInstance().newXPath();

            final XPathExpression schemaXpath = xPath.compile(SCHEMA_XPATH);
            final Node schemaNode = (Node) schemaXpath.evaluate(metadataDocument, XPathConstants.NODE);

            final XPathExpression deltaReturnTypesXpath = xPath.compile(DELTA_RETURN_PATHS_XPATH);
            // Even though the constant is named NODESET, the actual type returned by this evaluation is a NodeList
            final NodeList deltaReturnTypes = (NodeList) deltaReturnTypesXpath.evaluate(metadataDocument, XPathConstants.NODESET);

            return supportsUsersAndGroupsDeltaQuery(schemaNode, deltaReturnTypes);
        } catch (XPathExpressionException | IOException | SAXException e) {
            throw new RuntimeException("Failed to read the Microsoft Graph metadata document", e);
        }
    }

    /**
     * As per the <a href="http://docs.oasis-open.org/odata/odata/v4.0/errata02/os/complete/part3-csdl/odata-v4.0-errata02-os-part3-csdl-complete.html#_Nominal_Types">OData Metadata documentation</a>
     * a nominal type (in our case the returned type of delta queries) must be prefixed by either the Namespace or the Alias.
     * This method checks if the delta query return types contain at least one valid user type and at least one valid group type,
     * where valid means prefixed by either the namespace or the alias.
     * @param schemaNode the Node representing the Schema element in the metadata document
     * @param deltaReturnTypes a NodeList containing ReturnType elements for delta query functions
     * @return true if the deltaReturnTypes contains at least one valid user type and at least one valid group type, false otherwise
     */
    private boolean supportsUsersAndGroupsDeltaQuery(final Node schemaNode, final NodeList deltaReturnTypes) {
        final Set<String> possiblePrefixes = getPresentAttributeValues(schemaNode, ALIAS_ATTRIBUTE_NAME, NAMESPACE_ATTRIBUTE_NAME);
        final Set<String> userTypes = possiblePrefixes.stream()
                .map(prefix -> String.format(COLLECTION_TYPE_FORMAT, prefix, USER_SUFFIX))
                .collect(Collectors.toSet());
        final Set<String> groupTypes = possiblePrefixes.stream()
                .map(prefix -> String.format(COLLECTION_TYPE_FORMAT, prefix, GROUP_SUFFIX))
                .collect(Collectors.toSet());

        final Set<String> types = IntStream.range(0, deltaReturnTypes.getLength())
                .mapToObj(deltaReturnTypes::item)
                .flatMap(node -> getPresentAttributeValues(node, RETURN_TYPE_ATTRIBUTE_NAME).stream())
                .collect(Collectors.toSet());

        return !Collections.disjoint(userTypes, types) && !Collections.disjoint(groupTypes, types);
    }

    /**
     * Collects the values of those of the named attributes that are present on the node.
     *
     * @param node  the node to inspect; may be null
     * @param names the attribute names to look up
     * @return the values of the attributes that exist; empty if the node is null or has no attributes
     */
    private Set<String> getPresentAttributeValues(Node node, String... names) {
        final Optional<NamedNodeMap> attributes = Optional.ofNullable(node).map(Node::getAttributes);
        if (!attributes.isPresent()) {
            return Collections.emptySet();
        }
        final Set<String> result = new HashSet<>();
        for (String name : names) {
            attributes.map(a -> a.getNamedItem(name)).map(Node::getNodeValue).ifPresent(result::add);
        }
        return result;
    }

    /**
     * Throws a {@link UniformInterfaceException} when the response status is 300 or above.
     */
    private void checkStatusCode(ClientResponse response) {
        // This check mirrors Jersey's exception handling
        if (response.getStatus() >= 300) {
            // NOTE(review): this freshly created request has no properties set on it, so the buffering flag presumably
            // always resolves to the default (true) rather than to the client's configuration; confirm that is intended.
            final ClientRequestImpl request = new ClientRequestImpl(response.getLocation(), "GET");
            throw new UniformInterfaceException(response, request.getPropertyAsFeature(ClientConfig.PROPERTY_BUFFER_RESPONSE_ENTITY_ON_EXCEPTION, true));
        }
    }

    /**
     * Determines the charset of the response body.
     *
     * @return the charset named by the Content-Type's {@code charset} parameter, or UTF-8 if the response has no
     * Content-Type or the Content-Type has no charset parameter
     */
    private Charset extractEncoding(ClientResponse response) {
        final MediaType contentType = response.getType();
        if (contentType == null) {
            // A missing Content-Type header must not cause an NPE; fall back to the same default as a missing charset
            return Charsets.UTF_8;
        }
        return contentType.getParameters().entrySet().stream()
                // media type parameter names are case-insensitive
                .filter(entry -> CHARSET_PARAMETER_NAME.equalsIgnoreCase(entry.getKey()))
                .findFirst()
                .map(entry -> Charset.forName(entry.getValue()))
                .orElse(Charsets.UTF_8);
    }

    /**
     * @return the configured Graph API URL with {@value #GRAPH_API_VERSION} appended
     */
    @VisibleForTesting
    public String getGraphBaseResource() {
        return graphBaseEndpoint + GRAPH_API_VERSION;
    }

    /**
     * Fetches the page at the given link, used as-is.
     *
     * @param nextLink     the URL of the page to fetch (presumably a next-page link from a previous response; verify)
     * @param resultsClass the type to deserialize the JSON response into
     * @return the deserialized page
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public <T extends PageableGraphList> T getNextPage(final String nextLink, final Class<T> resultsClass) throws OperationFailedException {
        return handleRequest(() -> loggingResource(client.resource(nextLink)).accept(MediaType.APPLICATION_JSON_TYPE).get(resultsClass));
    }

    /**
     * Fetches the page at the given link after replacing its {@link ODataTop#QUERY_PARAM_NAME} query parameter with
     * the given limit.
     *
     * @param nextLink     the URL of the page to fetch
     * @param resultsClass the type to deserialize the JSON response into
     * @param limit        the page size to request instead of the one encoded in {@code nextLink}
     * @return the deserialized page
     * @throws OperationFailedException if the request fails with a {@link UniformInterfaceException}
     */
    public <T extends PageableGraphList> T getNextPage(final String nextLink, final Class<T> resultsClass, final ODataTop limit) throws OperationFailedException {
        final URI nextLinkWithUpdatedLimit = UriBuilder.fromUri(nextLink).replaceQueryParam(ODataTop.QUERY_PARAM_NAME, limit.asRawValue()).build();
        return handleRequest(() -> loggingResource(client.resource(nextLinkWithUpdatedLimit)).accept(MediaType.APPLICATION_JSON_TYPE).get(resultsClass));
    }

    /**
     * Runs the request and translates a {@link UniformInterfaceException} into an {@link OperationFailedException}
     * whose message carries the response status code and body.
     *
     * @param requestSupplier performs the request
     * @return whatever the supplier returns
     * @throws OperationFailedException if the supplier throws a {@link UniformInterfaceException}; the original
     *                                  exception is kept as the cause
     */
    @VisibleForTesting
    public <T> T handleRequest(Supplier<T> requestSupplier) throws OperationFailedException {
        try {
            return requestSupplier.get();
        } catch (UniformInterfaceException e) {
            final ClientResponse response = e.getResponse();
            final String message = String.format("Microsoft Graph API has returned an error response. Response status code: %d, content %s",
                    response.getStatus(),
                    readEntityQuietly(response)
            );
            throw new OperationFailedException(message, e);
        }
    }

    /**
     * Reads the body of an error response for diagnostics without ever failing. If the body cannot be read, a
     * placeholder is returned so the failure being reported is not replaced by a secondary one.
     */
    private static String readEntityQuietly(ClientResponse response) {
        try {
            return response.getEntity(String.class);
        } catch (RuntimeException e) {
            // Deliberately broad: this only runs while another failure is being reported and must not mask it
            log.debug("Unable to read the body of a Microsoft Graph error response", e);
            return UNREADABLE_ENTITY_PLACEHOLDER;
        }
    }

    /**
     * @return a resource rooted at {@link #getGraphBaseResource()} with logging enabled
     */
    private WebResource graphResource() {
        return loggingResource(client.resource(getGraphBaseResource()));
    }

    /**
     * Attaches a new {@link JerseyLoggingFilter} to the given resource, mutating it, and returns the same resource.
     */
    private WebResource loggingResource(WebResource baseResource) {
        baseResource.addFilter(new JerseyLoggingFilter());
        return baseResource;
    }
}
