package com.atlassian.braid.transformation;

import com.atlassian.braid.Extension;
import com.atlassian.braid.SchemaSource;
import com.atlassian.braid.graphql.language.KeyedDataFetchingEnvironment;
import com.atlassian.braid.java.util.BraidObjects;
import graphql.execution.DataFetcherResult;
import graphql.execution.MergedField;
import graphql.language.Argument;
import graphql.language.Directive;
import graphql.language.Field;
import graphql.language.FieldDefinition;
import graphql.language.ObjectTypeDefinition;
import graphql.language.StringValue;
import graphql.language.TypeDefinition;
import graphql.language.TypeName;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLSchema;
import graphql.schema.idl.TypeDefinitionRegistry;
import org.dataloader.BatchLoader;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.atlassian.braid.TypeUtils.unwrap;
import static graphql.schema.DataFetchingEnvironmentImpl.newDataFetchingEnvironment;
import static java.util.Collections.singletonMap;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

/**
 * A {@link SchemaTransformation} for processing extensions, which add fields to source object types. The fields to add
 * are the ones in the target object type, which is specified in the {@link Extension}. The field values are loaded
 * from a top-level query field of the target schema source.
 */
public class ExtensionSchemaTransformation implements SchemaTransformation {
    /**
     * {@code @deprecated(reason: "Extension field")} directive attached to every field that is copied from an
     * extension target type onto a source object type.
     */
    private static final Directive DEPRECATED_DIRECTIVE = Directive.newDirective()
            .name("deprecated")
            .argument(Argument.newArgument()
                    .name("reason")
                    .value(StringValue.newStringValue("Extension field").build())
                    .build())
            .build();

    /**
     * For every top-level query field whose type has extensions in the data source that owns the field, merges the
     * extension target type's fields into that type and creates one BatchLoader per extension.
     *
     * @param ctx the braiding context whose type registry is modified in place
     * @return batch loaders keyed by the data-loader key that the generated extension-field DataFetchers look up
     * @throws IllegalStateException if two extensions produce the same data-loader key
     */
    @Override
    public Map<String, BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>>> transform(BraidingContext ctx) {
        TypeDefinitionRegistry typeRegistry = ctx.getRegistry();
        ObjectTypeDefinition queryType = ctx.getQueryObjectTypeDefinition();

        // Compute the extension of a top-level field's type.
        return BraidTypeDefinition.getFieldDefinitions(queryType).stream()
                .flatMap(topLevelField -> ctx.getDataSources().values().stream()
                        .filter(braidSchemaSource -> braidSchemaSource.hasTypeAndField(typeRegistry, queryType, topLevelField))
                        .flatMap(braidSchemaSource -> {
                            String topLevelFieldType = braidSchemaSource.getSourceTypeName(unwrap(topLevelField.getType()));
                            return braidSchemaSource.getExtensions(topLevelFieldType).stream()
                                    .map(ext -> mergeType(braidSchemaSource, ctx, queryType, topLevelField, ext))
                                    .flatMap(m -> m.entrySet().stream());
                        }))
                .collect(toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Merges the fields of one extension's target type into the type of {@code field} and creates the BatchLoader
     * that the new fields' DataFetchers share.
     *
     * @param braidSchemaSource the data source that owns {@code field}
     * @param ctx               the braiding context
     * @param containingType    the type declaring {@code field} (the query type, as called from {@link #transform})
     * @param field             the top-level field whose type is being extended
     * @param ext               the extension to apply
     * @return a single-entry map from the data-loader key to the extension's BatchLoader
     * @throws IllegalArgumentException if the extension's namespace, target type or the field's type cannot be found
     */
    private Map<String, BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>>> mergeType(BraidSchemaSource braidSchemaSource,
                                                                                                   BraidingContext ctx,
                                                                                                   TypeDefinition containingType,
                                                                                                   FieldDefinition field,
                                                                                                   Extension ext) {
        ObjectTypeDefinition originalType = findRequiredOriginalType(ctx, braidSchemaSource, field);

        Set<String> originalTypeFieldNames = originalType.getFieldDefinitions().stream()
                .map(FieldDefinition::getName)
                .collect(toSet());

        BraidSchemaSource targetSource = ctx.getDataSources().get(ext.getBy().getNamespace());
        if (targetSource == null) {
            throw new IllegalArgumentException("No data source found for extension namespace: " + ext.getBy().getNamespace());
        }

        ObjectTypeDefinition targetType = findRequiredTargetType(targetSource, ext);

        // containingType is always the query type when called from transform(), so a key based on it alone would be
        // shared by every extension and make transform()'s toMap fail with a duplicate key. Qualify it with the
        // top-level field and the target type so that each extension gets its own data loader.
        String key = "ext-" + containingType.getName() + "-" + field.getName() + "-" + targetType.getName();

        // Add all the fields of the target object type to the source object type.
        wireNewFields(ctx, braidSchemaSource, containingType, field, ext, originalType, originalTypeFieldNames, targetType, key);

        // All the fields added share a common BatchLoader that loads the top-level field of the target schema source.
        SchemaSource schemaSource = targetSource.getSchemaSource();
        return singletonMap(key, schemaSource.newBatchLoader(schemaSource, new ExtensionTransformation(ext), ctx.getBatchLoaderEnvironment()));
    }

    /**
     * Looks up the extension's target ("by") type in the target data source.
     *
     * @throws IllegalArgumentException if the target data source has no such type
     */
    private ObjectTypeDefinition findRequiredTargetType(BraidSchemaSource targetSource, Extension ext) {
        return (ObjectTypeDefinition) targetSource.getType(ext.getBy().getType())
                .orElseThrow(() -> new IllegalArgumentException(
                        "Extension target type not found: " + ext.getBy().getType()));
    }

    /**
     * Looks up, in the braided registry, the (unwrapped) type of {@code fieldDef} as seen from the braided schema.
     *
     * @throws IllegalArgumentException if the registry has no such type
     */
    private ObjectTypeDefinition findRequiredOriginalType(BraidingContext ctx, BraidSchemaSource braidSchemaSource, FieldDefinition fieldDef) {
        String braidTypeName = braidSchemaSource.getBraidTypeName(unwrap(fieldDef.getType()));
        return (ObjectTypeDefinition) ctx.getRegistry().getType(braidTypeName)
                .orElseThrow(() -> new IllegalArgumentException("Type not found in registry: " + braidTypeName));
    }

    /**
     * Adds all the fields of the target object type to the source object type and registers DataFetchers for them.
     * The original type definition is replaced in the registry by a copy that carries the extra fields.
     */
    private void wireNewFields(BraidingContext ctx,
                               BraidSchemaSource braidSchemaSource,
                               TypeDefinition containingType,
                               FieldDefinition field,
                               Extension ext,
                               ObjectTypeDefinition originalType,
                               Set<String> originalTypeFieldNames,
                               ObjectTypeDefinition targetType,
                               String key) {
        List<FieldDefinition> fieldDefinitions = new ArrayList<>(originalType.getFieldDefinitions());
        // The type the new DataFetchers are registered on does not depend on the individual field.
        String braidType = braidSchemaSource.getBraidTypeName(ext.getType());
        targetType.getFieldDefinitions().stream()
                // Fields that already exist in the original field type are omitted.
                .filter(extField -> !originalTypeFieldNames.contains(extField.getName()))
                .forEach(extField -> {
                    fieldDefinitions.add(extField.transform(builder -> builder.directive(DEPRECATED_DIRECTIVE)));
                    String extFieldName = extField.getName();
                    DataFetcher<?> extFieldFetcher = buildDataFetcher(braidSchemaSource, containingType, field, key, extFieldName);
                    ctx.registerDataFetcher(braidType, extFieldName, extFieldFetcher);
                });
        ObjectTypeDefinition modifiedOriginalType = originalType.transform(
                builder -> builder.fieldDefinitions(fieldDefinitions)
        );
        ctx.getRegistry().remove(originalType);
        ctx.getRegistry().add(modifiedOriginalType);
    }

    /**
     * Builds the DataFetcher for one extension field. All extension fields of one extension share the data loader
     * registered under {@code key}, which loads the top-level field containing the extension fields. Each fetcher
     * then picks its own value out of that result.
     *
     * @return a DataFetcher whose result is a future of the extension field's value, or {@code null} if absent
     */
    private DataFetcher<?> buildDataFetcher(BraidSchemaSource braidSchemaSource,
                                            TypeDefinition containingType,
                                            FieldDefinition field,
                                            String key,
                                            String extFieldName) {
        return env -> env.getDataLoader(key)
                // Load the top-level field containing the extension fields
                .load(new KeyedDataFetchingEnvironment(updateDataFetchingEnvironment(braidSchemaSource, containingType, field, env)))
                .thenApply(BraidObjects::<DataFetcherResult<Map<String, Object>>>cast)
                // Get the individual extension field value from the containing top-level field value.
                .thenApply(dfr -> nullSafeGetFieldValue(dfr, extFieldName));
    }

    /**
     * Returns {@code dfr.getData().get(fieldName)}, or {@code null} if the result or its data is {@code null}.
     */
    private static Object nullSafeGetFieldValue(@Nullable DataFetcherResult<Map<String, Object>> dfr, String fieldName) {
        return Optional.ofNullable(dfr)
                .flatMap(r -> Optional.ofNullable(r.getData()))
                .map(data -> data.get(fieldName))
                .orElse(null);
    }

    /**
     * Copies {@code env} and re-targets the copy at the top-level {@code field} of {@code containingType}. The
     * copy is what the shared data loader receives as its key.
     */
    private static DataFetchingEnvironment updateDataFetchingEnvironment(BraidSchemaSource braidSchemaSource,
                                                                         TypeDefinition containingType,
                                                                         FieldDefinition field,
                                                                         DataFetchingEnvironment env) {
        final GraphQLSchema graphQLSchema = env.getGraphQLSchema();
        GraphQLObjectType containingGraphQLType = graphQLSchema.getObjectType(braidSchemaSource.getBraidTypeName(containingType.getName()));
        return newDataFetchingEnvironment(env)
                .source(env.getSource())
                .fieldDefinition(containingGraphQLType.getFieldDefinition(field.getName()))
                .mergedField(MergedField.newMergedField()
                        .addField(new Field(field.getName()))
                        .build())
                // unwrap rather than cast to TypeName, so that non-null (e.g. `Foo!`) field types do not fail with a
                // ClassCastException.
                .fieldType(graphQLSchema.getObjectType(unwrap(field.getType())))
                .parentType(graphQLSchema.getObjectType("Query"))
                .build();
    }
}
