package com.atlassian.braid.transformation;

import com.atlassian.braid.FieldKey;
import com.atlassian.braid.FieldTransformation;
import com.atlassian.braid.FieldTransformationContext;
import com.atlassian.braid.Link;
import com.atlassian.braid.LinkUtils.ResolvedArgument;
import com.atlassian.braid.java.util.BraidFutures;
import com.atlassian.braid.source.RelativeGraphQLError;
import graphql.GraphQLError;
import graphql.execution.DataFetcherResult;
import graphql.execution.MergedField;
import graphql.language.Argument;
import graphql.language.ArrayValue;
import graphql.language.Field;
import graphql.language.ObjectField;
import graphql.language.ObjectValue;
import graphql.language.OperationDefinition;
import graphql.language.Selection;
import graphql.language.Value;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.DataFetchingEnvironmentImpl;
import org.dataloader.DataLoader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.stream.Collectors;

import static com.atlassian.braid.ArgumentValueProvider.staticArgumentValue;
import static com.atlassian.braid.BatchLoaderUtils.getTargetIdsFromEnvironment;
import static com.atlassian.braid.LinkUtils.resolveArgumentsForLink;
import static com.atlassian.braid.transformation.QueryTransformationUtils.addFieldToQuery;
import static com.atlassian.braid.transformation.QueryTransformationUtils.cloneTrimAndAliasField;
import static java.util.Collections.singletonList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;


/**
 * A field transformation that processes a link to a target data source and generates the fields to fetch from that
 * source.
 */
public class LinkTransformation implements FieldTransformation {

    private static final Logger log = LoggerFactory.getLogger(LinkTransformation.class);
    private final Link link;
    // Not null-checked in the constructor; it is dereferenced only by the batching code path (batchQueries).
    private final BatchMapping batchMapping;

    /**
     * @param link         the link this transformation processes; must not be null
     * @param batchMapping the batch field/argument names used by {@link #batchQueries}; not validated here
     */
    LinkTransformation(Link link, BatchMapping batchMapping) {
        this.link = requireNonNull(link);
        this.batchMapping = batchMapping;
    }

    /**
     * @return the link this transformation was created for
     */
    public Link getLink() {
        return link;
    }

    /**
     * @return the batch mapping passed at construction time (may be null, the constructor does not check it)
     */
    public BatchMapping getBatchMapping() {
        return batchMapping;
    }

    /**
     * In most cases, the query to the target has a single top-level field. However, if the link's from field is of
     * list type, then the query to the link target contains a list of top-level fields.
     *
     * @param environment the data fetching environment of the field being transformed
     * @param context     the transformation context that accumulates query fields, variables and
     *                    short-circuited data
     * @return a future of the generated fields: one per target id for a simple link, a singleton list for a
     * complex link
     */
    @Override
    public CompletableFuture<List<Field>> apply(DataFetchingEnvironment environment, FieldTransformationContext context) {
        Field cloneOfCurrentField = environment.getField().deepCopy();
        List<Selection> selections = cloneOfCurrentField.getSelectionSet().getSelections();
        // null means "cannot short-circuit"; see selectFieldsForShortCircuit
        Set<String> shortCircuitFields = selectFieldsForShortCircuit(selections, link);
        if (link.isSimpleLink()) {
            return getTargetIdsFromEnvironment(link.getSourceInputFieldName(), environment)
                    .thenCompose(targetIds -> {
                        // One generated field per target id.
                        @SuppressWarnings("unchecked") CompletableFuture<Field>[] futureFields = targetIds.stream()
                                .map(targetId -> transformSimpleLink(targetId, context, environment, shortCircuitFields))
                                .toArray(CompletableFuture[]::new);
                        return BraidFutures.all(toList(), futureFields);
                    });
        } else {
            return transformComplexLink(context, environment, shortCircuitFields);
        }
    }

    /**
     * Builds the field for a simple link, resolving the link arguments from the given static argument value.
     *
     * @param argumentValue      the target id used as a static argument value
     * @param context            the transformation context
     * @param environment        the data fetching environment of the linked field
     * @param shortCircuitFields the fields that can be short-circuited, or null if short-circuiting is not possible
     * @return a future of the generated field
     */
    private CompletableFuture<Field> transformSimpleLink(Object argumentValue,
                                                         FieldTransformationContext context,
                                                         DataFetchingEnvironment environment,
                                                         Set<String> shortCircuitFields) {
        final FieldWithCounter field = cloneTrimAndAliasField(
                context,
                new ArrayList<>(),
                environment,
                true);
        CompletableFuture<List<ResolvedArgument>> args = resolveArgumentsForLink(this.link, context.getSchemaSource(),
                environment, staticArgumentValue(argumentValue), field.counter);
        return args.thenApply(resolvedArguments ->
                createFieldForSelection(field, environment, context, shortCircuitFields, resolvedArguments));
    }

    /**
     * Builds the field for a complex link, resolving the link arguments with the link's own argument value provider.
     *
     * @param context            the transformation context
     * @param environment        the data fetching environment of the linked field
     * @param shortCircuitFields the fields that can be short-circuited, or null if short-circuiting is not possible
     * @return a future of a singleton list holding the generated field
     */
    private CompletableFuture<List<Field>> transformComplexLink(FieldTransformationContext context,
                                                                DataFetchingEnvironment environment,
                                                                Set<String> shortCircuitFields) {
        final FieldWithCounter field = cloneTrimAndAliasField(
                context,
                new ArrayList<>(),
                environment,
                true);
        CompletableFuture<List<ResolvedArgument>> args = resolveArgumentsForLink(this.link, context.getSchemaSource(),
                environment, link.getArgumentValueProvider(), field.counter);
        return args.thenApply(resolvedArguments ->
                singletonList(createFieldForSelection(field, environment, context, shortCircuitFields, resolvedArguments)));
    }

    /**
     * Decides how the given field is served and records that decision in the context:
     * <ul>
     * <li>if a non-nullable link argument resolved to null, null is stored as short-circuited data for the field;</li>
     * <li>else if every selected field can be short-circuited, the matching argument values are stored as
     * short-circuited data for the field;</li>
     * <li>otherwise the field is turned into a query field and added to the query.</li>
     * </ul>
     *
     * @return the (possibly transformed) field held by {@code field}
     */
    private Field createFieldForSelection(FieldWithCounter field,
                                          DataFetchingEnvironment environment,
                                          FieldTransformationContext context,
                                          Set<String> shortCircuitFields,
                                          List<ResolvedArgument> resolvedArguments) {

        if (!areAllArgumentsValuesAllowed(resolvedArguments)) {
            context.getShortCircuitedData().put(new FieldKey(field.field.getAlias()), null);
        } else if (shortCircuitFields != null) {
            Map<String, Object> result = collectShortCircuitValues(shortCircuitFields, resolvedArguments);
            context.getShortCircuitedData().put(new FieldKey(field.field.getAlias()), result);
        } else {
            OperationDefinition operationDefinition = environment.getOperationDefinition();
            createQueryField(
                    context,
                    field,
                    resolvedArguments);
            addFieldToQuery(context, environment, operationDefinition, field);
        }
        return field.field;
    }

    /**
     * Maps each short-circuitable target field name to the value of the resolved argument that matches it.
     * <p>
     * A plain {@link HashMap} is filled by hand instead of using {@code Collectors.toMap}: nullable link arguments
     * may legitimately resolve to null (see {@link #areAllArgumentsValuesAllowed}), and {@code Collectors.toMap}
     * throws a {@link NullPointerException} on null values.
     *
     * @throws IllegalStateException if two resolved arguments match the same target field
     */
    private static Map<String, Object> collectShortCircuitValues(Set<String> shortCircuitFields,
                                                                 List<ResolvedArgument> resolvedArguments) {
        Map<String, Object> values = new HashMap<>();
        for (ResolvedArgument resolvedArgument : resolvedArguments) {
            String targetField = resolvedArgument.getLinkArgument().getTargetFieldMatchingArgument();
            if (!shortCircuitFields.contains(targetField)) {
                continue;
            }
            // containsKey (not get) so that an earlier null value still counts as a duplicate
            if (values.containsKey(targetField)) {
                throw new IllegalStateException("Duplicate key " + targetField);
            }
            values.put(targetField, resolvedArgument.getValue());
        }
        return values;
    }

    /**
     * Turns {@code field.field} into the query field sent to the link target: renames it to the link's top-level
     * query field, sets the resolved arguments on it, and registers the matching variable definitions and variable
     * values in the context. When the link has a custom transformation, query creation is delegated to it instead.
     */
    private void createQueryField(FieldTransformationContext fieldTransformationContext,
                                  FieldWithCounter field,
                                  List<ResolvedArgument> resolvedArguments) {
        if (link.getCustomTransformation() != null) {
            // NOTE(review): the return value (if any) of createQuery is ignored and null is passed as the second
            // argument; confirm this matches the custom transformation contract.
            link.getCustomTransformation().createQuery(field.field, null);
            return;
        }
        field.field = field.field.transform(
                builder -> builder.name(link.getTopLevelQueryField())
        );

        List<Argument> fieldArguments = new ArrayList<>(resolvedArguments.size());
        for (ResolvedArgument resolvedArgument : resolvedArguments) {
            fieldTransformationContext.addVariableDefinition(resolvedArgument.getVariableDefinition());
            fieldArguments.add(resolvedArgument.getArgument());
            fieldTransformationContext.getVariables()
                    .put(resolvedArgument.getVariableDefinition().getName(), resolvedArgument.getValue());
        }

        field.field = field.field.transform(
                builder -> builder.arguments(fieldArguments)
        );
    }


    /**
     * @return false if any argument resolved to null while its link argument is not nullable, true otherwise
     */
    private static boolean areAllArgumentsValuesAllowed(List<ResolvedArgument> resolvedArguments) {
        return resolvedArguments.stream()
                .noneMatch(arg -> arg.getValue() == null && !arg.getLinkArgument().isNullable());
    }

    /**
     * Gets the set of selected field names that can be short-circuited, or null if the selection contains at least
     * one entry that cannot be short-circuited. An empty selection list yields an empty (non-null) set.
     */
    private static Set<String> selectFieldsForShortCircuit(List<Selection> selections, Link link) {
        Set<String> selectionFields = new HashSet<>();
        for (Selection selection : selections) {
            if (!(selection instanceof Field)) {
                return null; // any non-Field selection (e.g. a fragment) disables short-circuiting
            }

            String fieldName = ((Field) selection).getName();
            if (!link.isFieldMatchingArgument(fieldName)) {
                return null; // the field is not backed by a link argument, so the target must be queried
            }
            selectionFields.add(fieldName);
        }
        return selectionFields;
    }

    /**
     * Reverses the transformation on a result. Delegates to the link's custom transformation when there is one;
     * otherwise the result is returned unchanged.
     */
    @Override
    public DataFetcherResult<Object> unapply(DataFetchingEnvironment environment, DataFetcherResult<Object> dataFetcherResult) {
        if (link.getCustomTransformation() != null) {
            return link.getCustomTransformation().unapplyForResult(environment.getField(), dataFetcherResult);
        }
        return dataFetcherResult;
    }

    /**
     * This batches multiple queries with same params by converting them to a single query with a list of argument values and then executes the single query.
     * For complete example, refer to com/atlassian/braid/testBraidBatchAllWithReactionsCase.yml
     *
     * @param environments               the environments of the individual queries being batched; the first one is
     *                                   used to look up the data loader and to build the batch environment
     * @param fieldTransformationContext the context holding the already-built operation and variables
     * @return a stage of the per-item results; when the batch result has no data, one result with null data is
     * produced per environment
     */
    public CompletionStage<List<DataFetcherResult<Object>>> batchQueries(
            List<DataFetchingEnvironment> environments,
            FieldTransformationContext fieldTransformationContext) {
        DataFetchingEnvironment env = environments.get(0);
        String dataLoaderKey = DataFetcherUtils.getDataLoaderKey("Query", batchMapping.batchField);

        OperationDefinition queryOp = fieldTransformationContext.getOperation();
        DataFetchingEnvironment batchEnv = getDataFetchingEnvironment(queryOp, fieldTransformationContext, env);

        DataLoader<DataFetchingEnvironment, DataFetcherResult<Object>> batchDataLoader =
                env.getDataLoader(dataLoaderKey);
        CompletableFuture<List<DataFetcherResult<Object>>> results = batchDataLoader.load(batchEnv)
                .thenApply(listResult -> {
                    // The batch field's data is treated as a list of per-item results.
                    @SuppressWarnings("unchecked")
                    List<Object> dataList = (List<Object>) listResult.getData();
                    // Errors should be carried over.
                    List<GraphQLError> errors = listResult.getErrors();
                    if (dataList == null) {
                        // No data at all: emit one null result per original environment, with relative errors
                        // re-based onto that environment's path.
                        return environments.stream()
                                .map(environment -> DataFetcherResult.newResult()
                                        .data(null)
                                        .errors(processRelativeGraphQLError(errors, environment))
                                        .localContext(listResult.getLocalContext())
                                        .build()
                                )
                                .collect(toList());
                    }
                    // NOTE(review): here every item carries the full, un-rebased error list, unlike the
                    // null-data branch above; confirm this asymmetry is intended.
                    return dataList.stream()
                            .map(data -> DataFetcherResult.newResult()
                                    .data(data)
                                    .errors(errors)
                                    .localContext(listResult.getLocalContext())
                                    .build()
                            )
                            .collect(toList());
                });
        batchDataLoader.dispatch();
        return results;
    }

    /**
     * This builds data fetching environment for batchField's dataFetcher by setting executionContext, fieldDefinition, fields and arguments.
     *
     * @param queryOp                    the operation whose top-level fields carry the arguments to batch
     * @param fieldTransformationContext the context supplying the variables for the new environment
     * @param env                        the environment used as the template for the new one
     * @return a copy of {@code env} targeting the batch field with the batched arguments
     */
    private DataFetchingEnvironment getDataFetchingEnvironment(
            OperationDefinition queryOp,
            FieldTransformationContext fieldTransformationContext,
            DataFetchingEnvironment env) {
        // Only plain fields of the operation contribute arguments; collect them once for both uses below.
        List<Field> topLevelFields = queryOp.getSelectionSet().getSelections().stream()
                .filter(selection -> selection instanceof Field)
                .map(selection -> (Field) selection)
                .collect(toList());

        // Batch the arguments into a list
        List<Argument> batchArgs = new ArrayList<>();
        /*
        batchArgValue: Required for building batchField. Gives list of Object Values which is a list of Object Fields.
            Example of batchArgValue:
               [ObjectValue{objectFields=[ObjectField{name='id', value=VariableReference{name='id100'}}, ObjectField{name='comment', value=VariableReference{name='comment100'}}]},
               ObjectValue{objectFields=[ObjectField{name='id', value=VariableReference{name='id101'}}, ObjectField{name='comment', value=VariableReference{name='comment101'}}]}]
         */
        ArrayValue batchArgValue = ArrayValue.newArrayValue().values(
                topLevelFields.stream()
                        .map(field -> field.getArguments().stream().map(argument -> new ObjectField(argument.getName(), argument.getValue()))
                                .collect(toList()))
                        .map(ObjectValue::new).collect(toList())
        ).build();
        batchArgs.add(Argument.newArgument().name(batchMapping.batchArgName).value(batchArgValue).build());

        Field batchField = env.getField().transform(builder -> builder
                .name(batchMapping.batchField)
                .arguments(batchArgs)
        );

        /*
        argsList: Required for setting arguments when building DataFetchingEnvironment. Gives list of map of argName and argValue.
            Example of argsList:
               [{"comment" -> "VariableReference{name='comment100'}", "id" -> "VariableReference{name='id100'}"},
               {"comment" -> "VariableReference{name='comment101'}", "id" -> "VariableReference{name='id101'}"}]
         */
        List<Map<String, Value>> argsList = topLevelFields.stream()
                .map(field -> field.getArguments()
                        .stream()
                        .collect(Collectors.toMap(Argument::getName, Argument::getValue)))
                .collect(toList());

        Map<String, Object> arguments = new HashMap<>();
        arguments.put(batchMapping.batchArgName, argsList);

        return DataFetchingEnvironmentImpl.newDataFetchingEnvironment(env)
                .variables(fieldTransformationContext.getVariables())
                .operationDefinition(queryOp)
                .fieldDefinition(env.getFieldDefinition())
                .mergedField(MergedField.newMergedField()
                        .addField(batchField)
                        .build())
                .arguments(arguments)
                .build();
    }

    /**
     * Re-bases every {@link RelativeGraphQLError} in {@code errors} onto the execution path of the given
     * environment; other errors are passed through unchanged.
     */
    private static List<GraphQLError> processRelativeGraphQLError(List<GraphQLError> errors, DataFetchingEnvironment environment) {
        return errors.stream()
                .map(error -> error instanceof RelativeGraphQLError
                        ? ((RelativeGraphQLError) error).updateBasePath(environment.getExecutionStepInfo().getPath())
                        : error
                )
                .collect(toList());
    }
}

