package com.atlassian.braid;

import com.atlassian.braid.graphql.language.AliasablePropertyDataFetcher;
import com.atlassian.braid.source.LinkMutation;
import com.atlassian.braid.source.TopLevelFieldMutation;
import graphql.execution.DataFetcherResult;
import graphql.language.FieldDefinition;
import graphql.language.ListType;
import graphql.language.NonNullType;
import graphql.language.ObjectTypeDefinition;
import graphql.language.OperationTypeDefinition;
import graphql.language.SchemaDefinition;
import graphql.language.Type;
import graphql.language.TypeDefinition;
import graphql.language.TypeName;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLSchema;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.SchemaGenerator;
import graphql.schema.idl.TypeDefinitionRegistry;
import org.dataloader.BatchLoader;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.atlassian.braid.TypeUtils.DEFAULT_MUTATION_TYPE_NAME;
import static com.atlassian.braid.TypeUtils.DEFAULT_QUERY_TYPE_NAME;
import static com.atlassian.braid.TypeUtils.MUTATION_FIELD_NAME;
import static com.atlassian.braid.TypeUtils.QUERY_FIELD_NAME;
import static com.atlassian.braid.TypeUtils.addMutationTypeToSchema;
import static com.atlassian.braid.TypeUtils.addQueryTypeToSchema;
import static com.atlassian.braid.TypeUtils.createDefaultQueryTypeDefinition;
import static com.atlassian.braid.TypeUtils.findMutationType;
import static com.atlassian.braid.TypeUtils.findQueryType;
import static com.atlassian.braid.java.util.BraidCollectors.singleton;
import static graphql.schema.DataFetchingEnvironmentBuilder.newDataFetchingEnvironment;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Stream.concat;

final class BraidSchema {

    private static final Logger log = LoggerFactory.getLogger(BraidSchema.class);

    private final GraphQLSchema schema;
    private final Map<String, BatchLoader> batchLoaders;

    /**
     * Creates the holder for a generated schema and the batch loaders that back it.
     *
     * @param schema       the executable GraphQL schema; must not be {@code null}
     * @param batchLoaders the batch loaders by data loader key; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    private BraidSchema(GraphQLSchema schema, Map<String, BatchLoader> batchLoaders) {
        this.schema = Objects.requireNonNull(schema);
        this.batchLoaders = Objects.requireNonNull(batchLoaders);
    }


    /**
     * Builds a {@link BraidSchema} from the given schema sources.
     * <p>
     * The supplied registry and runtime wiring builder are mutated: types, operation fields and data fetchers
     * of every source are added to them before the executable schema is generated.
     *
     * @param typeDefinitionRegistry the registry that receives the braided type definitions
     * @param runtimeWiringBuilder   the wiring builder that receives the data fetchers
     * @param schemaSources          the schema sources to braid together
     * @return the braided schema together with the batch loaders needed to execute it
     */
    static BraidSchema from(TypeDefinitionRegistry typeDefinitionRegistry,
                            RuntimeWiring.Builder runtimeWiringBuilder,
                            List<SchemaSource> schemaSources) {

        final Map<SchemaNamespace, BraidSchemaSource> sourcesByNamespace = toBraidSchemaSourceMap(schemaSources);

        // the registry needs a schema definition before the operation types are looked up
        findSchemaDefinitionOrCreateOne(typeDefinitionRegistry);

        // reuse the registry's query type, or register a default one
        final ObjectTypeDefinition queryType = findQueryType(typeDefinitionRegistry)
                .orElseGet(() -> addQueryTypeToSchema(typeDefinitionRegistry, createDefaultQueryTypeDefinition()));

        // a default mutation type is created but not registered yet (see below)
        final ObjectTypeDefinition mutationType = findMutationType(typeDefinitionRegistry)
                .orElseGet(TypeUtils::createDefaultMutationTypeDefinition);

        final Map<String, BatchLoader> loadersByKey = addDataSources(
                sourcesByNamespace, typeDefinitionRegistry, runtimeWiringBuilder, queryType, mutationType);

        // the mutation type is only added to the schema when it ends up with at least one field
        final boolean hasMutationFields = !mutationType.getFieldDefinitions().isEmpty();
        if (hasMutationFields) {
            addMutationTypeToSchema(typeDefinitionRegistry, mutationType);
        }

        final RuntimeWiring runtimeWiring = runtimeWiringBuilder.build();
        final GraphQLSchema executableSchema =
                new SchemaGenerator().makeExecutableSchema(typeDefinitionRegistry, runtimeWiring);

        return new BraidSchema(executableSchema, loadersByKey);
    }

    /**
     * Makes sure the registry has a schema definition, creating and registering a default one when it has none.
     *
     * @param typeDefinitionRegistry the registry to inspect and, if needed, extend
     */
    private static void findSchemaDefinitionOrCreateOne(TypeDefinitionRegistry typeDefinitionRegistry) {
        if (!typeDefinitionRegistry.schemaDefinition().isPresent()) {
            createDefaultSchemaDefinition(typeDefinitionRegistry);
        }
    }

    /**
     * Creates a schema definition and adds it to the registry.
     * <p>
     * A query (respectively mutation) operation is declared only when the registry already holds a type with the
     * default query (respectively mutation) type name.
     *
     * @param typeDefinitionRegistry the registry to read the default operation types from and to add the result to
     * @return the schema definition that was added to the registry
     */
    private static SchemaDefinition createDefaultSchemaDefinition(TypeDefinitionRegistry typeDefinitionRegistry) {
        final SchemaDefinition.Builder definitionBuilder = SchemaDefinition.newSchemaDefinition();

        if (typeDefinitionRegistry.getType(DEFAULT_QUERY_TYPE_NAME).isPresent()) {
            addOperation(definitionBuilder, QUERY_FIELD_NAME, DEFAULT_QUERY_TYPE_NAME);
        }

        if (typeDefinitionRegistry.getType(DEFAULT_MUTATION_TYPE_NAME).isPresent()) {
            addOperation(definitionBuilder, MUTATION_FIELD_NAME, DEFAULT_MUTATION_TYPE_NAME);
        }

        final SchemaDefinition created = definitionBuilder.build();
        typeDefinitionRegistry.add(created);
        return created;
    }

    /**
     * Declares an operation on the schema definition being built.
     * <p>
     * Despite the parameter names this is used for both the query and the mutation operation.
     *
     * @param schemaDefinition     the builder to add the operation to
     * @param queryFieldName       the operation name
     * @param defaultQueryTypeName the name of the type implementing the operation
     */
    private static void addOperation(SchemaDefinition.Builder schemaDefinition, String queryFieldName, String defaultQueryTypeName) {
        final TypeName operationType = new TypeName(defaultQueryTypeName);
        final OperationTypeDefinition operation = new OperationTypeDefinition(queryFieldName, operationType);
        schemaDefinition.operationTypeDefinition(operation);
    }

    /**
     * Adds the types, links and top level fields of all sources to the braided schema and wires a
     * {@link BraidDataFetcher} for every field that is backed by a batch loader.
     *
     * @param dataSources                  the schema sources by namespace
     * @param registry                     the registry receiving the non-operation types
     * @param runtimeWiringBuilder         the wiring builder receiving the data fetchers
     * @param queryObjectTypeDefinition    the braided query type
     * @param mutationObjectTypeDefinition the braided mutation type
     * @return the batch loaders by data loader key; when a field gets more than one loader the one registered
     * earlier is kept under the key suffixed with {@code -link}
     */
    private static Map<String, BatchLoader> addDataSources(Map<SchemaNamespace, BraidSchemaSource> dataSources,
                                                           TypeDefinitionRegistry registry,
                                                           RuntimeWiring.Builder runtimeWiringBuilder,
                                                           ObjectTypeDefinition queryObjectTypeDefinition,
                                                           ObjectTypeDefinition mutationObjectTypeDefinition) {
        addAllNonOperationTypes(dataSources, registry, runtimeWiringBuilder);

        // Order matters: link registrations come first so that a top level field registered afterwards for the
        // same key pushes the link loader to the "-link" key in the loop below.
        final List<FieldDataLoaderRegistration> registrations = new ArrayList<>();
        registrations.addAll(linkTypes(dataSources, queryObjectTypeDefinition, mutationObjectTypeDefinition));
        registrations.addAll(addSchemaSourcesTopLevelFieldsToOperation(
                dataSources, queryObjectTypeDefinition, BraidSchemaSource::getQueryType, BraidSchemaSource::getQueryFieldAlias));
        registrations.addAll(addSchemaSourcesTopLevelFieldsToOperation(
                dataSources, mutationObjectTypeDefinition, BraidSchemaSource::getMutationType, BraidSchemaSource::getMutationFieldAliases));

        final Map<String, BatchLoader> loadersByKey = new HashMap<>();
        for (FieldDataLoaderRegistration registration : registrations) {
            final String key = getDataLoaderKey(registration.type, registration.field);

            // keep a loader registered earlier for the same field reachable under the "-link" key
            final BatchLoader previous = loadersByKey.get(key);
            if (previous != null) {
                loadersByKey.put(key + "-link", previous);
            }

            runtimeWiringBuilder.type(registration.type,
                    wiring -> wiring.dataFetcher(registration.field, new BraidDataFetcher(key)));
            loadersByKey.put(key, registration.loader);
        }
        return loadersByKey;
    }

    /**
     * @return a read-only view of the batch loaders by data loader key
     */
    Map<String, BatchLoader> getBatchLoaders() {
        final Map<String, BatchLoader> readOnlyView = Collections.unmodifiableMap(batchLoaders);
        return readOnlyView;
    }

    /**
     * @return the executable GraphQL schema generated for the braided sources
     */
    public GraphQLSchema getSchema() {
        return this.schema;
    }

    /**
     * Data fetcher that loads its value through the data loader registered under a given key in the
     * {@link DataLoaderRegistry} of the {@link BraidContext}.
     * <p>
     * When a second loader is registered under the same key suffixed with {@code -link}, the value loaded by the
     * first loader is handed to that link loader as the source of a derived environment.
     */
    private static class BraidDataFetcher implements DataFetcher {
        private final String dataLoaderKey;

        private BraidDataFetcher(String dataLoaderKey) {
            this.dataLoaderKey = requireNonNull(dataLoaderKey);
        }

        @Override
        public Object get(DataFetchingEnvironment env) {
            final BraidContext context = env.getContext();
            final DataLoaderRegistry registry = context.getDataLoaderRegistry();
            final Object loadedValue = registry.getDataLoader(dataLoaderKey).load(env);

            // allows a top level field to also be linked
            final DataLoader<Object, Object> linkLoader = registry.getDataLoader(dataLoaderKey + "-link");
            if (linkLoader == null) {
                return loadedValue;
            }

            final DataFetchingEnvironment linkEnvironment = newDataFetchingEnvironment(env)
                    .source(loadedValue)
                    .fieldDefinition(env.getFieldDefinition())
                    .build();
            final Object linkedValue = linkLoader.load(linkEnvironment);

            // a null result from the link loader falls back to the value of the first loader
            return linkedValue != null ? linkedValue : loadedValue;
        }
    }


    /**
     * Adds the non-operation types of all sources to the registry and wires a property data fetcher for each of
     * their fields.
     * <p>
     * Type names must be unique across all sources. Every conflicting type is logged at error level before the
     * method fails.
     *
     * @param dataSources          the schema sources by namespace
     * @param registry             the registry the types are added to
     * @param runtimeWiringBuilder the wiring builder the field data fetchers are added to
     * @throws IllegalStateException if two or more types share the same name
     */
    private static void addAllNonOperationTypes(Map<SchemaNamespace, BraidSchemaSource> dataSources,
                                                TypeDefinitionRegistry registry,
                                                RuntimeWiring.Builder runtimeWiringBuilder) {

        final Map<String, List<BraidTypeDefinition>> typesByName = dataSources.values().stream()
                .map(BraidSchemaSource::getNonOperationTypes)
                .flatMap(Collection::stream)
                .collect(groupingBy(BraidTypeDefinition::getName));

        final List<BraidTypeDefinition> conflictingTypes = typesByName.values().stream()
                .filter(sameName -> sameName.size() > 1)
                .flatMap(Collection::stream)
                .collect(toList());

        if (!conflictingTypes.isEmpty()) {
            // report through the class logger rather than standard out so conflicts show up in application logs
            conflictingTypes.forEach(conflict ->
                    log.error("Type `{}` from {} is in conflict", conflict.getName(), conflict.getNamespace()));
            throw new IllegalStateException("Type name conflict exists");
        }

        // every group holds exactly one definition at this point
        for (List<BraidTypeDefinition> sameName : typesByName.values()) {
            final BraidTypeDefinition type = sameName.get(0);
            wireFieldDefinitions(runtimeWiringBuilder, type.getType(), type.getFieldDefinitions());
            registry.add(type.getType());
        }
    }

    /**
     * Wires an {@link AliasablePropertyDataFetcher} for each of the given fields of a type.
     *
     * @param runtimeWiringBuilder the wiring builder to add the data fetchers to
     * @param type                 the type owning the fields
     * @param fieldDefinitions     the fields to wire
     */
    private static void wireFieldDefinitions(RuntimeWiring.Builder runtimeWiringBuilder,
                                             TypeDefinition type,
                                             List<FieldDefinition> fieldDefinitions) {
        for (FieldDefinition fieldDefinition : fieldDefinitions) {
            final String typeName = type.getName();
            final String fieldName = fieldDefinition.getName();
            runtimeWiringBuilder.type(
                    typeName,
                    wiring -> wiring.dataFetcher(fieldName, new AliasablePropertyDataFetcher(fieldName)));
        }
    }

    /**
     * Adds the top level fields of one operation (query or mutation) of every source to the braided operation type.
     *
     * @param dataSources        the schema sources by namespace
     * @param braidOperationType the braided operation type receiving the fields
     * @param findOperationType  resolves the matching operation type of a source, if it has one
     * @param getFieldAlias      resolves the alias of a source field, if the field is exposed
     * @return the data loader registrations of all added fields
     */
    private static List<FieldDataLoaderRegistration> addSchemaSourcesTopLevelFieldsToOperation
            (Map<SchemaNamespace, BraidSchemaSource> dataSources,
             ObjectTypeDefinition braidOperationType,
             Function<BraidSchemaSource, Optional<ObjectTypeDefinition>> findOperationType,
             BiFunction<BraidSchemaSource, String, Optional<FieldAlias>> getFieldAlias) {
        final List<FieldDataLoaderRegistration> registrations = new ArrayList<>();
        for (BraidSchemaSource source : dataSources.values()) {
            registrations.addAll(
                    addSchemaSourceTopLevelFieldsToOperation(source, braidOperationType, findOperationType, getFieldAlias));
        }
        return registrations;
    }

    /**
     * Adds the top level fields of one operation of a single source to the braided operation type.
     *
     * @param source             the schema source to take the fields from
     * @param braidOperationType the braided operation type receiving the fields
     * @param findOperationType  resolves the matching operation type of the source, if it has one
     * @param getFieldAlias      resolves the alias of a source field, if the field is exposed
     * @return the data loader registrations of the added fields, empty when the source lacks the operation type
     */
    private static List<FieldDataLoaderRegistration> addSchemaSourceTopLevelFieldsToOperation(
            BraidSchemaSource source,
            ObjectTypeDefinition braidOperationType,
            Function<BraidSchemaSource, Optional<ObjectTypeDefinition>> findOperationType,
            BiFunction<BraidSchemaSource, String, Optional<FieldAlias>> getFieldAlias) {

        final Optional<ObjectTypeDefinition> sourceOperationType = findOperationType.apply(source);
        if (!sourceOperationType.isPresent()) {
            return emptyList();
        }
        return addSchemaSourceTopLevelFieldsToOperation(
                source, braidOperationType, sourceOperationType.get(), getFieldAlias);
    }

    /**
     * Copies the aliased fields of a source operation type onto the braided operation type and creates the data
     * loader registrations for them.
     *
     * @param schemaSource        the schema source owning the operation type
     * @param braidOperationType  the braided operation type whose field list is extended
     * @param sourceOperationType the operation type of the source
     * @param getFieldAlias       resolves the alias of a source field, if the field is exposed
     * @return the data loader registrations of the added fields
     */
    private static List<FieldDataLoaderRegistration> addSchemaSourceTopLevelFieldsToOperation(
            BraidSchemaSource schemaSource,
            ObjectTypeDefinition braidOperationType,
            ObjectTypeDefinition sourceOperationType,
            BiFunction<BraidSchemaSource, String, Optional<FieldAlias>> getFieldAlias) {

        // todo: smarter merge, optional namespacing, etc
        final List<BraidFieldDefinition> aliasedFields =
                aliasedFieldDefinitions(schemaSource, sourceOperationType, getFieldAlias);

        final List<FieldDefinition> targetFields = braidOperationType.getFieldDefinitions();
        for (BraidFieldDefinition aliasedField : aliasedFields) {
            targetFields.add(aliasedField.definition);
        }

        return wireOperationFields(braidOperationType.getName(), schemaSource, aliasedFields);
    }

    /**
     * Builds the aliased definitions of those fields of a source operation type for which an alias is available.
     * Fields without an alias are skipped.
     *
     * @param schemaSource        the schema source owning the operation type
     * @param sourceOperationType the operation type whose fields are inspected
     * @param getFieldAlias       resolves the alias of a source field, if the field is exposed
     * @return the aliased field definitions, in the order of the source fields
     */
    private static List<BraidFieldDefinition> aliasedFieldDefinitions(BraidSchemaSource schemaSource, ObjectTypeDefinition sourceOperationType, BiFunction<BraidSchemaSource, String, Optional<FieldAlias>> getFieldAlias) {
        final List<BraidFieldDefinition> aliased = new ArrayList<>();
        for (FieldDefinition definition : sourceOperationType.getFieldDefinitions()) {
            final Optional<FieldAlias> alias = getFieldAlias.apply(schemaSource, definition.getName());
            if (alias.isPresent()) {
                final BraidFieldDefinition sourceField = new BraidFieldDefinition(alias.get(), definition);
                aliased.add(aliasedFieldDefinition(schemaSource, sourceField));
            }
        }
        return aliased;
    }

    /**
     * Pairs a field definition with the alias it is exposed under.
     */
    private static final class BraidFieldDefinition {
        /** The alias of the field. */
        private final FieldAlias alias;
        /** The field definition the alias belongs to. */
        private final FieldDefinition definition;

        private BraidFieldDefinition(FieldAlias alias, FieldDefinition definition) {
            this.definition = definition;
            this.alias = alias;
        }
    }

    /**
     * Creates a copy of a field definition that uses the braid name of its alias and the aliased versions of its
     * type and input value definitions. Directives are carried over from the original definition.
     *
     * @param schemaSource         the schema source used to alias the type and the input value definitions
     * @param braidFieldDefinition the field definition and its alias
     * @return the same alias paired with the newly built field definition
     */
    private static BraidFieldDefinition aliasedFieldDefinition(BraidSchemaSource schemaSource, BraidFieldDefinition braidFieldDefinition) {
        final FieldAlias alias = braidFieldDefinition.alias;
        final FieldDefinition original = braidFieldDefinition.definition;
        final Type aliasedType = schemaSource.aliasType(original.getType());

        final FieldDefinition renamed = FieldDefinition.newFieldDefinition()
                .name(alias.getBraidName())
                .type(aliasedType)
                .inputValueDefinitions(schemaSource.aliasInputValueDefinitions(original.getInputValueDefinitions()))
                .directives(original.getDirectives())
                .build();

        return new BraidFieldDefinition(alias, renamed);
    }

    /**
     * Creates a data loader registration for each of the given operation fields.
     *
     * @param typeName         the name of the braided operation type owning the fields
     * @param schemaSource     the schema source the fields are fetched from
     * @param fieldDefinitions the aliased operation fields
     * @return one registration per field, in the order of the given fields
     */
    private static List<FieldDataLoaderRegistration> wireOperationFields(String typeName,
                                                                         BraidSchemaSource schemaSource,
                                                                         List<BraidFieldDefinition> fieldDefinitions) {
        final List<FieldDataLoaderRegistration> registrations = new ArrayList<>();
        for (BraidFieldDefinition operationField : fieldDefinitions) {
            registrations.add(wireOperationField(typeName, schemaSource, operationField));
        }
        return registrations;
    }

    /**
     * Creates the data loader registration of a single top level operation field.
     *
     * @param typeName       the name of the braided operation type owning the field
     * @param schemaSource   the schema source the field is fetched from
     * @param operationField the aliased operation field
     * @return the registration binding the field's braid name to a batch loader of the source
     */
    private static FieldDataLoaderRegistration wireOperationField(
            String typeName,
            BraidSchemaSource schemaSource,
            BraidFieldDefinition operationField) {

        final FieldAlias alias = operationField.alias;
        final SchemaSource delegate = schemaSource.getSchemaSource();
        final FieldMutation mutation = new TopLevelFieldMutation(alias);

        final BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> loader =
                newBatchLoader(delegate, mutation);

        return new FieldDataLoaderRegistration(typeName, alias.getBraidName(), loader);
    }

    /**
     * Applies the links declared by every schema source.
     * <p>
     * For each link the field definitions of the link's source type are modified in place so that the link's
     * source field has the target type, and a registration with a batch loader created from the target schema
     * source is collected for that field.
     *
     * @param sources                      the schema sources by namespace
     * @param queryObjectTypeDefinition    the braided query type, used when resolving the link's source type
     * @param mutationObjectTypeDefinition the braided mutation type, used when resolving the link's source type
     * @return one data loader registration per link
     * @throws IllegalArgumentException if the source type, the source "from" field, the target schema source or
     *                                  the target type of a link cannot be found
     */
    private static List<FieldDataLoaderRegistration> linkTypes(Map<SchemaNamespace, BraidSchemaSource> sources,
                                                               ObjectTypeDefinition queryObjectTypeDefinition,
                                                               ObjectTypeDefinition mutationObjectTypeDefinition) {
        List<FieldDataLoaderRegistration> fieldDataLoaderRegistrations = new ArrayList<>();
        for (BraidSchemaSource source : sources.values()) {
            TypeDefinitionRegistry typeRegistry = source.getTypeRegistry();

            // copy of the source's types by name, used to look up the link's source type
            Map<String, TypeDefinition> dsTypes = new HashMap<>(typeRegistry.types());

            for (Link link : source.getSchemaSource().getLinks()) {
                // replace the field's type
                ObjectTypeDefinition typeDefinition = getObjectTypeDefinition(queryObjectTypeDefinition,
                        mutationObjectTypeDefinition, typeRegistry, dsTypes, source.getLinkBraidSourceType(link));

                // fails fast when the "from" field is missing, so sourceFromField below is always present
                validateSourceFromFieldExists(link, typeDefinition);

                // the field that exposes the linked object; may not exist yet
                Optional<FieldDefinition> sourceField = typeDefinition.getFieldDefinitions().stream()
                        .filter(d -> d.getName().equals(link.getSourceField()))
                        .findFirst();

                // the field the link reads its value from
                Optional<FieldDefinition> sourceFromField = typeDefinition.getFieldDefinitions()
                        .stream()
                        .filter(Objects::nonNull)
                        .filter(s -> s.getName().equals(link.getSourceFromField()))
                        .findAny();

                if (link.isReplaceFromField()) {
                    // the "from" field is removed from the type when the link replaces it
                    typeDefinition.getFieldDefinitions().remove(sourceFromField.get());
                }

                BraidSchemaSource targetSource = sources.get(link.getTargetNamespace());
                if (targetSource == null) {
                    throw new IllegalArgumentException("Can't find target schema source: " + link.getTargetNamespace());
                }
                if (!targetSource.hasType(link.getTargetType())) {
                    throw new IllegalArgumentException("Can't find target type: " + link.getTargetType());

                }

                Type targetType = new TypeName(link.getTargetType());

                if (!sourceField.isPresent()) {
                    // Add source field to schema if not already there
                    // the new field is a list of the target type when the "from" field is a list
                    if (sourceFromField.isPresent() && isListType(sourceFromField.get().getType())) {
                        targetType = new ListType(targetType);
                    }
                    FieldDefinition field = new FieldDefinition(link.getSourceField(), targetType);
                    typeDefinition.getFieldDefinitions().add(field);
                } else if (isListType(sourceField.get().getType())) {
                    // keep the list wrapper, and the non-null wrapper around the list, of the existing field;
                    // the element type becomes the plain target type
                    if (sourceField.get().getType() instanceof NonNullType) {
                        sourceField.get().setType(new NonNullType(new ListType(targetType)));
                    } else {
                        sourceField.get().setType(new ListType(targetType));
                    }
                } else {
                    // Change source field type to the braided type
                    // NOTE(review): a non-null wrapper on a non-list field is not carried over here
                    sourceField.get().setType(targetType);
                }

                // the loader is created from the target source, registered for the source type's field
                fieldDataLoaderRegistrations.add(new FieldDataLoaderRegistration(
                        source.getLinkBraidSourceType(link),
                        link.getSourceField(),
                        newBatchLoader(targetSource.getSchemaSource(), new LinkMutation(link))));
            }
        }
        return fieldDataLoaderRegistrations;
    }

    /**
     * Tells whether a type is a list type, looking through one enclosing non-null wrapper.
     *
     * @param type the type to inspect, may be {@code null}
     * @return {@code true} for a list type or a non-null type directly wrapping a list type
     */
    private static boolean isListType(Type type) {
        final Type unwrapped = type instanceof NonNullType
                ? ((NonNullType) type).getType()
                : type;
        return unwrapped instanceof ListType;
    }

    /**
     * Resolves the object type a link originates from.
     * <p>
     * The type is first looked up by name among the source's types. When it is not found there and the name is
     * the one of the braided query type, the source's query type is used; likewise, when the name is the one of
     * the braided mutation type, the source's mutation type is used. The two fallbacks are independent: the
     * mutation fallback used to be nested inside the query one and was therefore unreachable unless both braided
     * operation types shared the same name.
     *
     * @param queryObjectTypeDefinition    the braided query type
     * @param mutationObjectTypeDefinition the braided mutation type
     * @param typeRegistry                 the type registry of the source declaring the link
     * @param dsTypes                      the types of the source by name
     * @param linkSourceType               the name of the link's source type
     * @return the object type definition the link originates from
     * @throws IllegalArgumentException if no matching type can be found
     */
    private static ObjectTypeDefinition getObjectTypeDefinition(ObjectTypeDefinition queryObjectTypeDefinition,
                                                                ObjectTypeDefinition mutationObjectTypeDefinition,
                                                                TypeDefinitionRegistry typeRegistry,
                                                                Map<String, TypeDefinition> dsTypes,
                                                                String linkSourceType) {
        ObjectTypeDefinition typeDefinition = (ObjectTypeDefinition) dsTypes.get(linkSourceType);

        // the source may name its query type differently from the braided one
        if (typeDefinition == null && linkSourceType.equals(queryObjectTypeDefinition.getName())) {
            typeDefinition = findQueryType(typeRegistry).orElse(null);
        }

        // same for the mutation type; checked on its own, not only when the query lookup was attempted
        if (typeDefinition == null && linkSourceType.equals(mutationObjectTypeDefinition.getName())) {
            typeDefinition = findMutationType(typeRegistry).orElse(null);
        }

        if (typeDefinition == null) {
            throw new IllegalArgumentException("Can't find source type: " + linkSourceType);
        }
        return typeDefinition;
    }

    /**
     * Builds the key under which the data loader of a field is registered: the type name and the field name
     * separated by a dot.
     *
     * @param sourceType  the name of the type owning the field
     * @param sourceField the name of the field
     * @return the data loader key, e.g. {@code Query.foo}
     */
    private static String getDataLoaderKey(String sourceType, String sourceField) {
        // String.join renders a null element as "null", exactly like string concatenation does
        return String.join(".", sourceType, sourceField);
    }

    /**
     * Asks a schema source for a new batch loader applying the given field mutation.
     *
     * @param schemaSource  the schema source creating the loader; it is also passed to itself as the first argument
     * @param fieldMutation the mutation passed on to the schema source
     * @return the batch loader created by the schema source
     */
    private static BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> newBatchLoader(
            SchemaSource schemaSource, FieldMutation fieldMutation) {
        // The BatchLoader is keyed by DataFetchingEnvironment because different fetches of the same object may
        // request different fields. Combining them smartly into a single fetch is left for the future.
        final BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> loader =
                schemaSource.newBatchLoader(schemaSource, fieldMutation);
        return loader;
    }

    /**
     * Checks that the type declares the field a link reads its value from.
     *
     * @param link           the link whose "from" field is looked for
     * @param typeDefinition the type expected to declare the field
     * @throws IllegalArgumentException if the type has no field with the name of the link's "from" field
     */
    private static void validateSourceFromFieldExists(Link link, ObjectTypeDefinition typeDefinition) {
        boolean found = false;
        for (FieldDefinition candidate : typeDefinition.getFieldDefinitions()) {
            if (candidate.getName().equals(link.getSourceFromField())) {
                found = true;
                break;
            }
        }

        if (!found) {
            throw new IllegalArgumentException(
                    format("Can't find source from field: %s", link.getSourceFromField()));
        }
    }

    /**
     * Wraps every schema source in a {@link BraidSchemaSource} and indexes the result by namespace.
     *
     * @param schemaSources the schema sources to wrap
     * @return the wrapped schema sources by namespace
     */
    private static Map<SchemaNamespace, BraidSchemaSource> toBraidSchemaSourceMap(List<SchemaSource> schemaSources) {
        final Map<SchemaNamespace, BraidSchemaSource> sourcesByNamespace = schemaSources.stream()
                .map(schemaSource -> new BraidSchemaSource(schemaSource))
                .collect(Collectors.groupingBy(BraidSchemaSource::getNamespace, singleton()));
        return sourcesByNamespace;
    }

    /**
     * Binds a field of a type to the batch loader that fetches it.
     */
    private static class FieldDataLoaderRegistration {
        /** The name of the type owning the field. */
        private final String type;
        /** The name of the field. */
        private final String field;
        /** The batch loader fetching the field. */
        private final BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> loader;

        private FieldDataLoaderRegistration(String type, String field,
                                            BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> loader) {
            this.loader = loader;
            this.field = field;
            this.type = type;
        }
    }
}
