package com.atlassian.braid.source;

import com.atlassian.braid.Link;
import com.atlassian.braid.SchemaSource;
import com.atlassian.braid.transformation.BraidSchemaSource;
import graphql.analysis.QueryTraversal;
import graphql.analysis.QueryVisitor;
import graphql.analysis.QueryVisitorFieldEnvironment;
import graphql.analysis.QueryVisitorStub;
import graphql.language.Field;
import graphql.language.FragmentDefinition;
import graphql.language.FragmentSpread;
import graphql.language.Node;
import graphql.language.NodeTraverser;
import graphql.language.NodeVisitorStub;
import graphql.language.SelectionSet;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLFieldsContainer;
import graphql.schema.GraphQLObjectType;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static graphql.analysis.QueryTraversal.newQueryTraversal;

/**
 * Rewrites the field selections of a query node according to the links declared by a
 * {@link SchemaSource}, and collects the fragment definitions the rewritten node still needs.
 */
public class TrimFieldsSelection {


    /**
     * Trims the field selection under {@code root} based on the links of {@code schemaSource}.
     *
     * <p>For every visited field (excluding {@code __typename}):
     * <ul>
     *   <li>if a link matches it via {@link #getLinkWithSeparateInputField}, its sub-selection is
     *       cleared and it is renamed to the link's source input field name. If a sibling with that
     *       name and the same alias already exists in the parent field's selection set, the field is
     *       removed from that selection set instead of producing a duplicate;</li>
     *   <li>if a link matches it via {@link #getLinkWithFieldUsedAsInput}, its sub-selection is
     *       cleared.</li>
     * </ul>
     *
     * <p>Fields reachable from {@code root} are modified in place. The environment's fragment
     * definitions are deep-copied before traversal, so edits made inside fragments land on the
     * copies, which are the ones returned here.
     *
     * @param schemaSource     the source whose links drive the rewrite
     * @param environment      supplies the schema, parent type, fragments and variables for traversal
     * @param root             the node whose selection is trimmed (mutated)
     * @param ignoreFirstField when {@code true}, the top-level field (no parent environment) is not
     *                         itself rewritten; its children are still visited
     * @return the (copied) fragment definitions transitively referenced from {@code root}, in
     * discovery order
     */
    public static List<FragmentDefinition> trimFieldSelection(SchemaSource schemaSource, DataFetchingEnvironment environment, Node root, boolean ignoreFirstField) {
        List<FieldWithLink> clearSubSelectionAndRenameField = new ArrayList<>();
        List<FieldWithLink> clearSubSelection = new ArrayList<>();
        Map<Field, SelectionSet> parentSelectionSets = new LinkedHashMap<>();

        BraidSchemaSource braidSchemaSource = new BraidSchemaSource(schemaSource);

        QueryVisitor nodeVisitor = new QueryVisitorStub() {

            @Override
            public void visitField(QueryVisitorFieldEnvironment env) {
                if (env.isTypeNameIntrospectionField()) {
                    return;
                }
                GraphQLFieldsContainer parentFieldsContainer = env.getFieldsContainer();
                Field field = env.getField();
                QueryVisitorFieldEnvironment parentEnv = env.getParentEnvironment();
                if (parentEnv == null && ignoreFirstField) {
                    return;
                }
                String typeName = parentFieldsContainer.getName();
                Optional<Link> linkWithSeparateInputField = getLinkWithSeparateInputField(
                        braidSchemaSource, schemaSource.getLinks(), typeName, field.getName());
                Optional<Link> linkWithFieldUsedAsInput = getLinkWithFieldUsedAsInput(schemaSource.getLinks(), typeName, field.getName());

                linkWithSeparateInputField.ifPresent(link -> {
                    clearSubSelectionAndRenameField.add(new FieldWithLink(field, link));
                    // Only non-top-level fields have a parent selection set that can be checked
                    // for a sibling that would collide with the renamed field.
                    if (parentEnv != null) {
                        parentSelectionSets.put(field, parentEnv.getField().getSelectionSet());
                    }
                });
                linkWithFieldUsedAsInput.ifPresent(link -> clearSubSelection.add(new FieldWithLink(field, link)));
            }
        };

        // Deep copies so that rewriting fields inside fragments does not touch the environment's originals.
        Map<String, FragmentDefinition> fragmentsByName = environment.getFragmentsByName().entrySet()
                .stream().collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().deepCopy()));


        QueryTraversal queryTraversal = newQueryTraversal()
                .schema(environment.getGraphQLSchema())
                .root(root)
                .rootParentType((GraphQLObjectType) environment.getParentType())
                .fragmentsByName(fragmentsByName)
                .variables(environment.getExecutionContext().getVariables()).build();
        queryTraversal.visitPreOrder(nodeVisitor);

        // Must run before renaming: it compares the still-original names against the target names.
        removeDuplicationsCausedByInputFields(clearSubSelectionAndRenameField, parentSelectionSets);

        clearSubSelection.forEach(fieldWithLink -> fieldWithLink.field.setSelectionSet(null));
        clearSubSelectionAndRenameField.forEach(fieldWithLink -> {
            fieldWithLink.field.setSelectionSet(null);
            fieldWithLink.field.setName(fieldWithLink.link.getSourceInputFieldName());
        });


        Set<FragmentDefinition> referencedFragments = new LinkedHashSet<>();
        getReferencedFragments(root, fragmentsByName, referencedFragments);
        return new ArrayList<>(referencedFragments);

    }

    /**
     * Removes fields that, once renamed to their link's source input field name, would duplicate a
     * field (same name and same alias) already present in their parent's selection set.
     *
     * <p>NOTE(review): the recorded selection set is the parent <em>field's</em> selection set. A
     * field reached through an inline fragment or fragment spread is not a direct child of it, so
     * neither the duplicate check nor the removal applies to such fields. Verify whether that case
     * matters.
     *
     * @param clearSubSelectionAndRenameField fields scheduled for renaming, with their links
     * @param parentSelectionSets             parent selection set per field; top-level fields are absent
     */
    private static void removeDuplicationsCausedByInputFields(List<FieldWithLink> clearSubSelectionAndRenameField, Map<Field, SelectionSet> parentSelectionSets) {
        Map<SelectionSet, List<Field>> fieldsToRemove = new LinkedHashMap<>();
        for (FieldWithLink fieldWithLink : clearSubSelectionAndRenameField) {
            SelectionSet selectionSet = parentSelectionSets.get(fieldWithLink.field);
            if (selectionSet == null) {
                continue;
            }
            String sourceInputFieldName = fieldWithLink.link.getSourceInputFieldName();
            String currentFieldAlias = fieldWithLink.field.getAlias();
            if (selectionSetContainsField(selectionSet, sourceInputFieldName, currentFieldAlias)) {
                fieldsToRemove.computeIfAbsent(selectionSet, notUsed -> new ArrayList<>())
                        .add(fieldWithLink.field);
            }
        }
        // Removal is deferred until after the scan so no selection list is modified while it is being checked.
        fieldsToRemove.forEach((selectionSet, fields) -> selectionSet.getSelections().removeAll(fields));
    }

    /**
     * Checks whether {@code selectionSet} directly contains a {@link Field} with the given name and
     * alias. Both {@code null} aliases count as equal.
     */
    private static boolean selectionSetContainsField(SelectionSet selectionSet, String name, String alias) {
        return selectionSet.getSelections().stream()
                .filter(selection -> selection instanceof Field)
                .map(field -> (Field) field)
                .anyMatch(field -> field.getName().equals(name) && Objects.equals(field.getAlias(), alias));
    }

    /**
     * Recursively collects the fragment definitions referenced from the given root node.
     *
     * <p>Spreads whose name has no entry in {@code fragmentDefinitionMap} are skipped. A fragment
     * already in {@code referencedFragments} is not descended into again, which also stops
     * recursion on cyclic spreads.
     *
     * @param root                  the node to look for fragment spreads in
     * @param fragmentDefinitionMap the fragments defined in the query, keyed by name
     * @param referencedFragments   accumulator of fragments found so far (mutated)
     */
    private static void getReferencedFragments(Node root,
                                               Map<String, FragmentDefinition> fragmentDefinitionMap,
                                               Set<FragmentDefinition> referencedFragments) {
        Set<FragmentDefinition> childFragments = new LinkedHashSet<>();
        NodeVisitorStub nodeVisitorStub = new NodeVisitorStub() {
            @Override
            public TraversalControl visitFragmentSpread(FragmentSpread fragmentSpread, TraverserContext<Node> context) {
                FragmentDefinition definition = fragmentDefinitionMap.get(fragmentSpread.getName());
                // An unknown spread would otherwise add null to the result and recurse into a null root.
                if (definition != null) {
                    childFragments.add(definition);
                }
                return TraversalControl.CONTINUE;
            }
        };
        new NodeTraverser().preOrder(nodeVisitorStub, root);
        for (FragmentDefinition fragment : childFragments) {
            // Set.add returns false for fragments already collected, so each one is descended into only once.
            if (referencedFragments.add(fragment)) {
                getReferencedFragments(fragment, fragmentDefinitionMap, referencedFragments);
            }
        }
    }

    /**
     * Finds the first link that meets all of these conditions: its Braid-side source type is
     * {@code typeName}, its new field name is {@code fieldName}, and its source input field name
     * differs from {@code fieldName}. Fields matched this way get renamed.
     */
    private static Optional<Link> getLinkWithSeparateInputField(BraidSchemaSource braidSchemaSource, Collection<Link> links, String typeName, String fieldName) {
        return links.stream()
                .filter(l -> braidSchemaSource.getLinkBraidSourceType(l).equals(typeName)
                        && l.getNewFieldName().equals(fieldName)
                        && !l.getSourceInputFieldName().equals(fieldName))
                .findFirst();
    }

    /**
     * Finds the first link whose source type is {@code typeName} and whose source input field name
     * is {@code fieldName}.
     */
    private static Optional<Link> getLinkWithFieldUsedAsInput(Collection<Link> links, String typeName, String fieldName) {
        return links.stream()
                .filter(l -> l.getSourceType().equals(typeName))
                .filter(l -> l.getSourceInputFieldName().equals(fieldName))
                .findFirst();
    }

    /** Immutable pairing of a query field with the link that matched it. */
    private static class FieldWithLink {
        final Field field;
        final Link link;

        FieldWithLink(Field field, Link link) {
            this.field = field;
            this.link = link;
        }
    }


}

