/*
 * Decompiled with CFR 0.152.
 */
package com.atlassian.braid.transformation;

import com.atlassian.braid.BatchLoaderEnvironment;
import com.atlassian.braid.Link;
import com.atlassian.braid.LinkArgument;
import com.atlassian.braid.SchemaNamespace;
import com.atlassian.braid.SchemaSource;
import com.atlassian.braid.TypeUtils;
import com.atlassian.braid.transformation.BatchMapping;
import com.atlassian.braid.transformation.BatchUtils;
import com.atlassian.braid.transformation.BraidSchemaSource;
import com.atlassian.braid.transformation.BraidingContext;
import com.atlassian.braid.transformation.DataFetcherUtils;
import com.atlassian.braid.transformation.LinkTransformation;
import com.atlassian.braid.transformation.SchemaTransformation;
import graphql.execution.DataFetcherResult;
import graphql.language.FieldDefinition;
import graphql.language.InputValueDefinition;
import graphql.language.ListType;
import graphql.language.NonNullType;
import graphql.language.ObjectTypeDefinition;
import graphql.language.SDLDefinition;
import graphql.language.Type;
import graphql.language.TypeDefinition;
import graphql.parser.Parser;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.idl.TypeDefinitionRegistry;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dataloader.BatchLoader;
import org.dataloader.DataLoader;

/**
 * Schema transformation that materialises every {@link Link} declared by the braided schema sources.
 *
 * <p>For each link this transformation:
 * <ul>
 *   <li>validates that the link's source object fields, target schema source and target type exist;</li>
 *   <li>unless {@code link.isNoSchemaChangeNeeded()}, adds (or retypes) the link's new field on the link's
 *       source object type, plus the batch query field when the target's top-level field has a batch mapping;</li>
 *   <li>registers a data fetcher for the new field that delegates to a data loader; and</li>
 *   <li>creates the batch loader that backs that data loader, keyed by
 *       {@link DataFetcherUtils#getLinkDataLoaderKey}.</li>
 * </ul>
 */
public class LinkSchemaTransformation implements SchemaTransformation {

    /**
     * Applies all links of all data sources in {@code braidingContext}.
     *
     * <p>Mutates the braid type registry and/or the per-source type registries (see
     * {@link #addLinkFields}) and registers one data fetcher per link on the context.
     *
     * @param braidingContext the context holding the data sources, registries and root type definitions
     * @return batch loaders keyed by link data-loader key, one per link
     * @throws IllegalArgumentException if a link's source type, source fields, target source or target type
     *                                  cannot be found
     * @throws IllegalStateException    if a required top-level/batch query field or root type is missing
     */
    @Override
    public Map<String, BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>>> transform(BraidingContext braidingContext) {
        Map<SchemaNamespace, BraidSchemaSource> sources = braidingContext.getDataSources();
        ObjectTypeDefinition queryObjectTypeDefinition = braidingContext.getQueryObjectTypeDefinition();
        ObjectTypeDefinition mutationObjectTypeDefinition = braidingContext.getMutationObjectTypeDefinition();
        BatchLoaderEnvironment batchLoaderEnvironment = braidingContext.getBatchLoaderEnvironment();
        TypeDefinitionRegistry braidTypeRegistry = braidingContext.getRegistry();
        Map<String, BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>>> batchLoaders = new HashMap<>();

        for (BraidSchemaSource source : sources.values()) {
            TypeDefinitionRegistry sourceTypeRegistry = source.getTypeRegistry();
            SchemaSource sourceSchemaSource = source.getSchemaSource();
            TypeDefinitionRegistry privateTypes = sourceSchemaSource.getPrivateSchema();

            for (Link link : sourceSchemaSource.getLinks()) {
                String linkSourceType = source.getLinkBraidSourceType(link);
                ObjectTypeDefinition linkSourceObjectType = resolveLinkSourceObjectType(
                        queryObjectTypeDefinition, mutationObjectTypeDefinition,
                        braidTypeRegistry, sourceTypeRegistry, linkSourceType);

                validateSourceFromFieldExists(source, link, privateTypes);

                BraidSchemaSource targetSource = sources.get(link.getTargetNamespace());
                if (targetSource == null) {
                    throw new IllegalArgumentException("Can't find target schema source: " + link.getTargetNamespace());
                }
                if (!targetSource.hasType(TypeUtils.unwrap(Parser.parseType(link.getTargetType())))) {
                    throw new IllegalArgumentException("Can't find target type: " + link.getTargetType());
                }

                FieldDefinition topLevelField = topLevelFieldForLink(link, targetSource);
                BatchMapping batchMapping = BatchUtils.getBatchMapping(topLevelField);
                if (!link.isNoSchemaChangeNeeded()) {
                    addLinkFields(link, linkSourceObjectType, topLevelField, batchMapping, targetSource,
                            braidTypeRegistry, sourceTypeRegistry);
                }

                String field = link.getNewFieldName();
                String linkDataLoaderKey = DataFetcherUtils.getLinkDataLoaderKey(linkSourceType, field);
                // The fetcher only enqueues the environment; resolution happens in the batch loader registered below.
                DataFetcher dataFetcher = env -> {
                    DataLoader<DataFetchingEnvironment, DataFetcherResult<Object>> dataLoader = env.getDataLoader(linkDataLoaderKey);
                    return dataLoader.load(env);
                };
                braidingContext.registerDataFetcher(linkSourceType, field, dataFetcher);

                SchemaSource targetSchemaSource = targetSource.getSchemaSource();
                BatchLoader<DataFetchingEnvironment, DataFetcherResult<Object>> batchLoader = targetSchemaSource.newBatchLoader(
                        targetSchemaSource, new LinkTransformation(link, batchMapping), batchLoaderEnvironment);
                batchLoaders.put(linkDataLoaderKey, batchLoader);
            }
        }
        return batchLoaders;
    }

    /**
     * Finds the object type definition that a link's new field should be attached to.
     *
     * <p>When the resolved type is the braid registry's Query (or Mutation) type, the source
     * registry's own Query (or Mutation) type is returned instead, so that {@link #addLinkFields}
     * edits the source's root type rather than the braided one.
     *
     * @throws IllegalArgumentException if the link source type cannot be found
     * @throws IllegalStateException    if the source registry lacks the corresponding root type
     */
    private static ObjectTypeDefinition resolveLinkSourceObjectType(ObjectTypeDefinition queryObjectTypeDefinition,
                                                                    ObjectTypeDefinition mutationObjectTypeDefinition,
                                                                    TypeDefinitionRegistry braidTypeRegistry,
                                                                    TypeDefinitionRegistry sourceTypeRegistry,
                                                                    String linkSourceType) {
        // Read fresh on every call: earlier links may already have replaced definitions in the registry.
        Map<String, TypeDefinition> dsTypes = braidTypeRegistry.types();
        ObjectTypeDefinition typeDefinition = getObjectTypeDefinition(
                queryObjectTypeDefinition, mutationObjectTypeDefinition, braidTypeRegistry, dsTypes, linkSourceType);
        if (typeDefinition.equals(TypeUtils.findQueryType(braidTypeRegistry).orElse(null))) {
            typeDefinition = TypeUtils.findQueryType(sourceTypeRegistry).orElseThrow(() -> new IllegalStateException(
                    "Source schema has no query type for link source type: " + linkSourceType));
        }
        if (typeDefinition.equals(TypeUtils.findMutationType(braidTypeRegistry).orElse(null))) {
            typeDefinition = TypeUtils.findMutationType(sourceTypeRegistry).orElseThrow(() -> new IllegalStateException(
                    "Source schema has no mutation type for link source type: " + linkSourceType));
        }
        return typeDefinition;
    }

    /**
     * Adds the link's fields to {@code linkSourceObjectType} and stores the updated definition.
     *
     * <p>The updated definition replaces the old one in {@code sourceTypeRegistry} when the type is that
     * registry's Query or Mutation type, otherwise in {@code braidTypeRegistry}.
     *
     * @throws IllegalStateException if {@code batchMapping} names a query field the target does not have
     */
    private static void addLinkFields(Link link,
                                      ObjectTypeDefinition linkSourceObjectType,
                                      FieldDefinition topLevelField,
                                      BatchMapping batchMapping,
                                      BraidSchemaSource targetSource,
                                      TypeDefinitionRegistry braidTypeRegistry,
                                      TypeDefinitionRegistry sourceTypeRegistry) {
        List<FieldDefinition> fieldDefinitions = modifySchema(link, linkSourceObjectType, topLevelField);
        if (batchMapping != null) {
            FieldDefinition batchFieldDef = findQueryField(targetSource, batchMapping.batchField)
                    .orElseThrow(() -> new IllegalStateException("Could not find query field: " + batchMapping.batchField));
            fieldDefinitions.add(batchFieldDef);
        }
        ObjectTypeDefinition updatedTypeDefinition =
                linkSourceObjectType.transform(builder -> builder.fieldDefinitions(fieldDefinitions));

        boolean isSourceRootType = linkSourceObjectType.equals(TypeUtils.findQueryType(sourceTypeRegistry).orElse(null))
                || linkSourceObjectType.equals(TypeUtils.findMutationType(sourceTypeRegistry).orElse(null));
        TypeDefinitionRegistry owningRegistry = isSourceRootType ? sourceTypeRegistry : braidTypeRegistry;
        owningRegistry.remove((SDLDefinition) linkSourceObjectType);
        owningRegistry.add((SDLDefinition) updatedTypeDefinition);
    }

    /**
     * Looks up a field by name on the Query type of the target source's private schema.
     *
     * @return the field, or empty if the private schema has no Query type or no such field
     */
    private static Optional<FieldDefinition> findQueryField(BraidSchemaSource targetSource, String fieldName) {
        return TypeUtils.findQueryType(targetSource.getSchemaSource().getPrivateSchema())
                .flatMap(queryType -> queryType.getFieldDefinitions().stream()
                        .filter(fieldDefinition -> Objects.equals(fieldName, fieldDefinition.getName()))
                        .findFirst());
    }

    /**
     * Returns the target source's top-level query field that the link resolves through.
     *
     * @throws IllegalStateException if the target's private schema has no such query field
     */
    private static FieldDefinition topLevelFieldForLink(Link link, BraidSchemaSource targetSource) {
        return findQueryField(targetSource, link.getTopLevelQueryField())
                .orElseThrow(() -> new IllegalStateException(String.format(
                        "Cannot find top level query field '%s' in source '%s' for link on field '%s' defined in '%s'",
                        link.getTopLevelQueryField(), link.getTargetNamespace(),
                        link.getNewFieldName(), link.getSourceNamespace())));
    }

    /**
     * Computes the field list of {@code typeDefinition} after applying {@code link}.
     *
     * <ul>
     *   <li>Object-field link arguments flagged {@code isRemoveInputField()} have their source field removed.</li>
     *   <li>If the type has no field named {@code link.getNewFieldName()}, one is appended. Its type is the
     *       link target type (non-null if {@code link.targetNonNullable()}), wrapped in a list for simple links
     *       whose source input field is a list. Its arguments are copied from the top-level query field for
     *       each field-argument link argument.</li>
     *   <li>Otherwise the existing field is retyped to the target type, keeping its list / non-null list wrapping.</li>
     * </ul>
     *
     * @return a new, mutable list (callers may append further fields)
     */
    private static List<FieldDefinition> modifySchema(Link link, ObjectTypeDefinition typeDefinition, FieldDefinition topLevelField) {
        List<FieldDefinition> fieldDefinitions = new ArrayList<>(typeDefinition.getFieldDefinitions());
        FieldDefinition existingField = fieldDefinitions.stream()
                .filter(d -> d.getName().equals(link.getNewFieldName()))
                .findFirst()
                .orElse(null);
        // Captured before removals so simple-link type adjustment still sees removed input fields.
        Map<String, FieldDefinition> objectFields = fieldDefinitions.stream()
                .filter(Objects::nonNull)
                .collect(Collectors.toMap(FieldDefinition::getName, Function.identity()));

        link.getLinkArguments().stream()
                .filter(linkArgument -> linkArgument.getArgumentSource() == LinkArgument.ArgumentSource.OBJECT_FIELD
                        && linkArgument.isRemoveInputField())
                .map(LinkArgument::getSourceName)
                .map(objectFields::get)
                .filter(Objects::nonNull)
                .forEach(fieldDefinitions::remove);

        Type targetType = Parser.parseType(link.getTargetType());
        if (link.targetNonNullable()) {
            targetType = NonNullType.newNonNullType(targetType).build();
        }

        if (existingField == null) {
            Type newFieldType = adjustTypeForSimpleLink(link, objectFields, targetType);
            List<InputValueDefinition> inputValueDefs = link.getLinkArguments().stream()
                    .filter(linkArgument -> linkArgument.getArgumentSource() == LinkArgument.ArgumentSource.FIELD_ARGUMENT)
                    .flatMap(linkArgument -> buildInputValueDefinitionForLink(topLevelField, linkArgument))
                    .collect(Collectors.toList());
            fieldDefinitions.add(FieldDefinition.newFieldDefinition()
                    .name(link.getNewFieldName())
                    .type(newFieldType)
                    .inputValueDefinitions(inputValueDefs)
                    .build());
        } else {
            final Type replacementType;
            if (!isListType(existingField.getType())) {
                replacementType = targetType;
            } else if (existingField.getType() instanceof NonNullType) {
                replacementType = new NonNullType(new ListType(targetType));
            } else {
                replacementType = new ListType(targetType);
            }
            fieldDefinitions.remove(existingField);
            fieldDefinitions.add(existingField.transform(builder -> builder.type(replacementType)));
        }
        return fieldDefinitions;
    }

    /**
     * Copies the top-level query field's argument named {@code linkArgument.getQueryArgumentName()}
     * as a new argument named {@code linkArgument.getSourceName()} with the same type.
     *
     * @return a one-element stream, or an empty stream if the top-level field has no such argument
     */
    private static Stream<InputValueDefinition> buildInputValueDefinitionForLink(FieldDefinition topLevelField, LinkArgument linkArgument) {
        return topLevelField.getInputValueDefinitions().stream()
                .filter(input -> linkArgument.getQueryArgumentName().equals(input.getName()))
                .findFirst()
                .map(input -> Stream.of(InputValueDefinition.newInputValueDefinition()
                        .name(linkArgument.getSourceName())
                        .type(input.getType())
                        .build()))
                .orElse(Stream.empty());
    }

    /**
     * For a simple link whose source input field is a list (or non-null list), wraps {@code targetType}
     * in a list; otherwise returns {@code targetType} unchanged.
     */
    private static Type adjustTypeForSimpleLink(Link link, Map<String, FieldDefinition> objectFields, Type targetType) {
        if (!link.isSimpleLink()) {
            return targetType;
        }
        FieldDefinition sourceInputField = objectFields.get(link.getSourceInputFieldName());
        if (sourceInputField != null && isListType(sourceInputField.getType())) {
            return new ListType(targetType);
        }
        return targetType;
    }

    /**
     * Resolves {@code linkSourceType} to an object type definition.
     *
     * <p>Looks in {@code dsTypes} first. If absent, falls back to the braid registry's Query type when
     * the name matches {@code queryObjectTypeDefinition}, and/or to its Mutation type when the name
     * matches {@code mutationObjectTypeDefinition}. Either root definition may be null.
     *
     * @throws IllegalArgumentException if no object type can be found, or the named type is not an object type
     */
    private static ObjectTypeDefinition getObjectTypeDefinition(ObjectTypeDefinition queryObjectTypeDefinition,
                                                                ObjectTypeDefinition mutationObjectTypeDefinition,
                                                                TypeDefinitionRegistry braidTypeRegistry,
                                                                Map<String, TypeDefinition> dsTypes,
                                                                String linkSourceType) {
        TypeDefinition candidate = dsTypes.get(linkSourceType);
        if (candidate != null && !(candidate instanceof ObjectTypeDefinition)) {
            throw new IllegalArgumentException("Link source type is not an object type: " + linkSourceType);
        }
        ObjectTypeDefinition typeDefinition = (ObjectTypeDefinition) candidate;
        // The two fallbacks are independent: previously the mutation lookup was nested under the query-name
        // check and therefore unreachable for Mutation-sourced links.
        if (typeDefinition == null && queryObjectTypeDefinition != null
                && linkSourceType.equals(queryObjectTypeDefinition.getName())) {
            typeDefinition = TypeUtils.findQueryType(braidTypeRegistry).orElse(null);
        }
        if (typeDefinition == null && mutationObjectTypeDefinition != null
                && linkSourceType.equals(mutationObjectTypeDefinition.getName())) {
            typeDefinition = TypeUtils.findMutationType(braidTypeRegistry).orElse(null);
        }
        if (typeDefinition == null) {
            throw new IllegalArgumentException("Can't find source type: " + linkSourceType);
        }
        return typeDefinition;
    }

    /**
     * Ensures the link's source type exists in the private schema and declares every field
     * referenced by the link's object-field arguments.
     *
     * @throws IllegalArgumentException if the source type or any referenced source field is missing
     */
    private static void validateSourceFromFieldExists(BraidSchemaSource source, Link link, TypeDefinitionRegistry privateTypeDefinitionRegistry) {
        String sourceType = source.getSourceTypeName(link.getSourceType());
        ObjectTypeDefinition typeDefinition = privateTypeDefinitionRegistry.getType(sourceType, ObjectTypeDefinition.class)
                .orElseThrow(() -> new IllegalArgumentException(String.format(
                        "Can't find source type '%s' in private schema for link %s", sourceType, link.getNewFieldName())));
        Map<String, FieldDefinition> fieldsByName = typeDefinition.getFieldDefinitions().stream()
                .collect(Collectors.toMap(FieldDefinition::getName, Function.identity()));
        List<String> missingSourceObjectFields = link.getLinkArguments().stream()
                .filter(linkArgument -> linkArgument.getArgumentSource() == LinkArgument.ArgumentSource.OBJECT_FIELD)
                .map(LinkArgument::getSourceName)
                .filter(sourceName -> !fieldsByName.containsKey(sourceName))
                .collect(Collectors.toList());
        if (!missingSourceObjectFields.isEmpty()) {
            throw new IllegalArgumentException("Can't find source from field: " + String.join(", ", missingSourceObjectFields));
        }
    }

    /** Returns whether {@code type} is a list type, optionally wrapped in a single non-null. */
    private static boolean isListType(Type type) {
        return type instanceof ListType
                || (type instanceof NonNullType && ((NonNullType) type).getType() instanceof ListType);
    }
}

