package org.mule.weave.v2.ts.resolvers

import org.mule.weave.v2.parser.ShouldTypeDynamicGenericFunction
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.WeaveLocationCapable
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameter
import org.mule.weave.v2.parser.ast.functions.OverloadedFunctionNode
import org.mule.weave.v2.parser.ast.types.DynamicReturnTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.scope.AstNavigator
import org.mule.weave.v2.ts.DynamicReturnType
import org.mule.weave.v2.ts.Edge
import org.mule.weave.v2.ts.EdgeLabels
import org.mule.weave.v2.ts.FunctionType
import org.mule.weave.v2.ts.FunctionTypeHelper
import org.mule.weave.v2.ts.FunctionTypeParameter
import org.mule.weave.v2.ts.ReferenceResolver
import org.mule.weave.v2.ts.ScopeGraphTypeReferenceResolver
import org.mule.weave.v2.ts.Substitution
import org.mule.weave.v2.ts.TypeNode
import org.mule.weave.v2.ts.TypeParameter
import org.mule.weave.v2.ts.WeaveType
import org.mule.weave.v2.ts.WeaveTypeReferenceResolver
import org.mule.weave.v2.ts.WeaveTypeResolutionContext
import org.mule.weave.v2.ts.WeaveTypeResolver

/**
  * Type resolver for [[FunctionNode]]s.
  *
  * It builds the [[FunctionType]] of a function definition (type parameters, parameters and return type) and, when an
  * expected return type is known, pushes it into the type graph of the function body.
  *
  * @param typeReferenceResolver resolver used to turn the declared parameter type nodes into [[WeaveType]]s when the body is validated
  * @param referenceResolver     resolver handed to every [[DynamicReturnType]] created by this resolver
  */
class FunctionTypeResolver(typeReferenceResolver: WeaveTypeReferenceResolver, referenceResolver: ReferenceResolver) extends WeaveTypeResolver {

  /**
    * Propagates the expected return type of this function into its body's sub graph.
    *
    * The expected return type is the declared return type of the function when present. Otherwise it is the return
    * type of the incoming expected type, if that is a [[FunctionType]]. When one is found, every outgoing edge of the
    * body node in the first function sub graph gets it as expected type. A reverse executor is then run over that graph.
    *
    * @return always an empty sequence; propagation happens inside the sub graph, not through this node's edges
    */
  override def resolveExpectedType(node: TypeNode, incomingExpectedType: Option[WeaveType], ctx: WeaveTypeResolutionContext): Seq[(Edge, WeaveType)] = {
    val functionNode = getFunctionNode(node)
    val propagatedReturnType: Option[WeaveType] = incomingExpectedType.collect({
      case ft: FunctionType => ft.returnType
    })
    val expectedReturnType: Option[WeaveType] = functionNode.returnType
      .map((rt) => WeaveType(rt, new ScopeGraphTypeReferenceResolver(ctx.currentScopeNavigator)))
      .orElse(propagatedReturnType)

    expectedReturnType.foreach((returnType) => {
      val mayBeGraph = ctx.getFunctionSubGraphs(functionNode).flatMap(_.headOption.map(_._2))
      mayBeGraph.foreach((graph) => {
        graph
          .findLocalNode(functionNode.body)
          .foreach((bodyNode) => {
            bodyNode
              .outgoingEdges()
              .foreach((edge) => edge.updateExpectedType(returnType))
          })
        val propagator = ctx.newReverseExecutorWithContext(ctx.currentScopeNavigator, graph, ctx.currentParsingContext)
        propagator.run()
      })
    })
    Seq()
  }

  /**
    * Resolves the [[FunctionType]] of the function definition.
    *
    * Parameter types come from the PARAM_TYPE incoming edges. A parameter without one gets a dynamic parameter type.
    * The return type is computed as follows:
    *  - if any parameter is untyped, the return type is a [[DynamicReturnType]];
    *  - if the declared return type is a dynamic return type node, the return type is a [[DynamicReturnType]];
    *  - if a return type is declared, the body is validated against it and the (abstract) declared type is used;
    *  - otherwise the return type is inferred by simulating the body, falling back to a [[DynamicReturnType]].
    *
    * Every reference to a type parameter declared by this function is replaced by its abstract type parameter.
    * A warning is reported when a generic function ends up with a dynamic return type and declares no return type.
    *
    * @return always `Some` function type
    */
  override def resolveReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext): Option[WeaveType] = {
    val functionName: Option[String] = getFunctionName(node)
    val functionNode: FunctionNode = getFunctionNode(node)
    val functionParamsNodes: Seq[FunctionParameter] = functionNode.params.paramList

    val functionTypeParams: Seq[TypeParameter] = functionNode.typeParameterList
      .map((typeParams) => {
        // The resolver only depends on the root scope, so look it up once for all the type parameters
        val resolver = ctx.currentScopeNavigator.rootScope.referenceResolver()
        typeParams.typeParameters.map((tp) => resolver.resolveAbstractTypeParameter(tp).asInstanceOf[TypeParameter])
      })
      .getOrElse(Seq())

    val concreteToAbstractTypeParameters: Substitution = buildTypeParamSubstitution(ctx, functionNode)
    val concreteParameters: Seq[FunctionTypeParameter] =
      functionParamsNodes
        .filterNot(_.variable.name == NameIdentifier.INSERTED_FAKE_VARIABLE_NAME)
        .zipWithIndex
        .map({
          case (param, index) =>
            val parameterName: String = param.variable.name
            val maybeParameterType: Option[WeaveType] = node.incomingType(EdgeLabels.PARAM_TYPE(parameterName))
            val parameterType: WeaveType = maybeParameterType.getOrElse(FunctionTypeHelper.createDynamicParameter(index))
            val defaultValueType: Option[WeaveType] = node.incomingType(EdgeLabels.DEFAULT_VALUE(parameterName))
            FunctionTypeParameter(parameterName, parameterType, param.defaultValue.isDefined, defaultValueType)
        })

    // Type parameters in a function definition are placeholders, so the function signature uses abstract type parameters
    val abstractParameters: Seq[FunctionTypeParameter] = concreteParameters.map((fp) => {
      fp.copy(wtype = concreteToAbstractTypeParameters.apply(ctx, fp.wtype))
    })

    val expectedReturnType: Option[WeaveType] = node.incomingType(EdgeLabels.RETURN_TYPE)
    val abstractReturnType: Option[WeaveType] =
      expectedReturnType.map((rt) => concreteToAbstractTypeParameters.apply(ctx, rt))

    val returnType: WeaveType =
      if (functionParamsNodes.exists(_.wtype.isEmpty)) {
        createDynamicReturnType(node, ctx, functionName, functionNode, abstractParameters, abstractReturnType)
      } else {
        functionNode.returnType match {
          case Some(_: DynamicReturnTypeNode) =>
            createDynamicReturnType(node, ctx, functionName, functionNode, abstractParameters, None)
          case Some(_) =>
            // Validate that the body matches the declared return type.
            // The body is validated against the concrete expected type, as it is checked with concrete parameters.
            val drt: DynamicReturnType = createDynamicReturnType(node, ctx, functionName, functionNode, abstractParameters, abstractReturnType)
            validateFunctionBody(ctx, drt, functionParamsNodes, expectedReturnType)
            // The type of the function itself uses the abstract type.
            // NOTE(review): assumes a declared return type always yields an incoming RETURN_TYPE type.
            abstractReturnType.get
          case None =>
            // No declared return type: try to infer it by simulating the body with the declared parameter types
            val drt: DynamicReturnType = createDynamicReturnType(node, ctx, functionName, functionNode, abstractParameters, None)
            validateFunctionBody(ctx, drt, functionParamsNodes, None)
              .map((concreteType) => concreteToAbstractTypeParameters.apply(ctx, concreteType))
              .getOrElse(drt)
        }
      }

    val functionType: FunctionType = FunctionType(
      functionTypeParams,
      abstractParameters,
      returnType,
      name = functionName,
      customReturnTypeResolver = functionName.flatMap((name) => CustomFunctionTypeResolver.findRegisteredCustomTypeResolver(name, concreteParameters, ctx)))

    if (FunctionTypeHelper.hasDynamicReturn(functionType) && functionTypeParams.nonEmpty && functionNode.returnType.isEmpty) {
      // Report the warning on the function name identifier when there is one, as it looks cleaner.
      // Otherwise report it on the function's own AST node (never on the graph TypeNode itself).
      val msg = ShouldTypeDynamicGenericFunction(functionType)
      val reportedAstNode: AstNode = getFunctionNameIdentifier(node).getOrElse(node.astNode)
      val reportedNodeLocation = reportedAstNode.asInstanceOf[WeaveLocationCapable].location()
      ctx.unsafeLocationWarning(msg, node, reportedNodeLocation)
    }

    // Documentation comes from the node itself when it has one.
    // Otherwise it comes from its parent (VariableDirective/FunctionDirective).
    if (node.astNode.hasWeaveDoc) {
      setDocumentationAnnotation(node.astNode, functionType)
    } else {
      val navigator: AstNavigator = ctx.currentScopeNavigator.rootScope.astNavigator()
      navigator
        .parentOf(node.astNode)
        .foreach((parent) => setDocumentationAnnotation(parent, functionType))
    }
    Some(functionType)
  }

  /**
    * Builds a substitution from the concrete type parameters declared by this function to their abstract type parameters.
    * This is required because a function definition's type must be expressed with abstract type parameters.
    */
  private def buildTypeParamSubstitution(ctx: WeaveTypeResolutionContext, functionNode: FunctionNode): Substitution = {
    val rootReferenceResolver: ScopeGraphTypeReferenceResolver = ctx.currentScopeNavigator.rootScope.referenceResolver()
    val concreteToAbstract: Seq[(TypeParameter, TypeParameter)] = functionNode.typeParameterList match {
      case Some(typeParams) =>
        typeParams.typeParameters
          .map((tp) => rootReferenceResolver.resolveAbstractTypeParameter(tp))
          .collect({
            case abstractTypeParameter: TypeParameter =>
              (rootReferenceResolver.getConcreteTypeOf(abstractTypeParameter), abstractTypeParameter)
          })
      case None => Seq()
    }
    Substitution(concreteToAbstract)
  }

  /**
    * Creates a [[DynamicReturnType]] for this function, bound to the node's parent graph and the current scope navigator.
    */
  private def createDynamicReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext, mayBeName: Option[String], functionNode: FunctionNode, arguments: Seq[FunctionTypeParameter], expectedReturnType: Option[WeaveType]): DynamicReturnType = {
    DynamicReturnType(arguments, functionNode, node.parentGraph, ctx.currentScopeNavigator, mayBeName, expectedReturnType, referenceResolver)
  }

  /** Returns the AST node of `node` as a [[FunctionNode]]; this resolver is only attached to function nodes. */
  private def getFunctionNode(node: TypeNode): FunctionNode = {
    node.astNode.asInstanceOf[FunctionNode]
  }

  /** Copies the weave doc of `fdn`, if any, as documentation of `functionType`. */
  private def setDocumentationAnnotation(fdn: AstNode, functionType: WeaveType): Unit = {
    fdn.weaveDoc.foreach((doc) => {
      functionType.withDocumentation(doc.literalValue, doc.location())
    })
  }

  /**
    * Resolves the return type of `drt` using the declared parameter types as arguments, in strict mode.
    * Using the declared parameters replaces any abstract type with its concrete type.
    * Callers must ensure every parameter declares a type (`param.wtype` is read with `get`).
    *
    * @return the resolved return type, if the helper could resolve one
    */
  private def validateFunctionBody(ctx: WeaveTypeResolutionContext, drt: DynamicReturnType, paramList: Seq[FunctionParameter], expectedReturnType: Option[WeaveType]): Option[WeaveType] = {
    val functionArguments: Seq[WeaveType] = paramList.map((param) => WeaveType(param.wtype.get, typeReferenceResolver))
    FunctionTypeHelper.resolveReturnType(functionArguments, expectedReturnType, ctx, drt, strict = true)
  }

  /**
    * Finds the name identifier of the function.
    *
    * The first outgoing target whose AST node is a [[NameIdentifier]] wins. Otherwise the search continues through the
    * first outgoing target whose AST node is an [[OverloadedFunctionNode]].
    */
  private def getFunctionNameIdentifier(node: TypeNode): Option[NameIdentifier] = {
    val targets = node.outgoingEdges().map(_.target)
    targets
      .map(_.astNode)
      .collectFirst({ case name: NameIdentifier => name })
      .orElse({
        targets
          .collectFirst({
            case target: TypeNode if target.astNode.isInstanceOf[OverloadedFunctionNode] => getFunctionNameIdentifier(target)
          })
          .flatten
      })
  }

  /** Returns the name of the function, when it can be found (see `getFunctionNameIdentifier`). */
  def getFunctionName(node: TypeNode): Option[String] = getFunctionNameIdentifier(node).map(_.name)
}
