package org.mule.weave.v2.grammar

import org.mule.weave.v2.parser.SafeStringBasedParserInput
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.CommentType
import org.mule.weave.v2.parser.ast.ContainerAstNode
import org.mule.weave.v2.parser.ast.header.HeaderNode
import org.mule.weave.v2.parser.ast.header.directives.DirectiveNode
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.location.WeaveLocation
import org.parboiled2._

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Parser for a "Weave Document" (see [[weavedocument]]).
  *
  * @param input               the text to parse
  * @param resourceName        name given to the resulting [[ModuleNode]] when the document has no body
  * @param defaultErrorTrace   value reported as [[errorTraceCollectionLimit]]
  * @param attachDocumentation when `true`, collected comments are attached to AST nodes once the root is built
  */
class Grammar(val input: SafeStringBasedParserInput, val resourceName: NameIdentifier, val defaultErrorTrace: Int = Grammar.DEFAULT_ERROR_TRACE, val attachDocumentation: Boolean = true) extends Parser with MappingGrammar with ModuleGrammar {
  /**
    * Assigns every collected comment to the AST node it most likely belongs to, then clears the
    * collected comments.
    *
    * Comments starting at index 0 are attached to `documentNode` itself. Doc comments go to the first
    * node on a following line. Any other comment prefers the next node on the same line, then the
    * previous node on the same line, then the first node on a following line. When no candidate is
    * found, the comment goes to the most specific node enclosing it (see [[parentOf]]).
    *
    * @param documentNode root of the parsed document
    * @return `documentNode`, with the comments attached
    */
  def assignCommentNodes(documentNode: AstNode): AstNode = {
    val lineToNodeCache = new LineToNodeCache(documentNode)

    for ((_, comment) <- _comments) {
      val commentLocation = comment.location()
      val commentStartLine = commentLocation.startPosition.line
      val commentEndLine = commentLocation.endPosition.line
      val commentEndIndex = commentLocation.endPosition.index
      val commentStartIndex = commentLocation.startPosition.index

      if (commentStartIndex == 0) {
        documentNode.addComment(comment)
      } else {
        val astNode = comment.commentType match {
          case CommentType.DocComment => lineToNodeCache.lookFirstOnNextLine(commentEndLine)
          case _ =>
            lineToNodeCache.lookNextOnSameLine(commentEndLine, commentEndIndex)
              .orElse(lineToNodeCache.lookPreviousOnSameLine(commentStartLine, commentStartIndex))
              .orElse(lineToNodeCache.lookFirstOnNextLine(commentEndLine))
        }
        astNode.getOrElse(parentOf(commentLocation, documentNode)).addComment(comment)
      }
    }
    _comments.clear()
    documentNode
  }

  override def errorTraceCollectionLimit: Int = defaultErrorTrace

  /**
    * Header rule: whitespace, an optional comment line, then the directives and trailing whitespace.
    * The directives are wrapped into a [[HeaderNode]] by `createHeaderNode2`.
    */
  def headerExpr: Rule1[HeaderNode] = rule {
    pushPosition ~ ((ws ~ optional(comments ~ eol)) ~ directives ~ ws ~> createHeaderNode2) ~ injectPosition
  }

  val createHeaderNode2 = (directives: Seq[DirectiveNode]) => {
    HeaderNode(directives)
  }

  /**
    * Root rule ("Weave Document").
    *
    * The rule runs `setSyntaxVersion` as a non-consuming lookahead (`&`). It then expects one of two forms:
    *  - a header, optionally followed by `headerSeparator`, whitespace and `content`;
    *  - otherwise, a `HeaderNode.withVersion()` placeholder followed by optional `content`.
    *
    * Trailing whitespace and the end of input must follow. The result is built by [[createRoot]].
    */
  def weavedocument: Rule1[AstNode] = namedRule("Weave Document") {
    pushPosition ~ &(setSyntaxVersion) ~ ((headerExpr ~ optional(headerSeparator ~ ws ~ content) | push(HeaderNode.withVersion()) ~ optional(content)) ~ ws ~ EOI ~> createRoot) ~ injectPosition
  }

  /**
    * Builds the root node of the parse.
    *
    * The result is a [[DocumentNode]] when a body was parsed. Otherwise it is a [[ModuleNode]] named
    * `resourceName` holding the header's directives. When `attachDocumentation` is set, collected
    * comments are attached to the tree before it is returned.
    */
  val createRoot: (HeaderNode, Option[AstNode]) => AstNode = (headerNode: HeaderNode, expr: Option[AstNode]) => {
    val result: AstNode = expr match {
      case Some(body) => DocumentNode(headerNode, body)
      case None       => ModuleNode(resourceName, headerNode.directives)
    }
    if (attachDocumentation) {
      assignCommentNodes(result)
    }
    result
  }

  /**
    * Returns the most specific node that contains the given location.
    *
    * Starting at `rootNode`, the search repeatedly descends into the first child whose index range
    * encloses `location`. It stops when no child does.
    *
    * @param location the span to locate
    * @param rootNode where the search starts; returned as is when no child encloses `location`
    */
  def parentOf(location: WeaveLocation, rootNode: AstNode): AstNode = {
    def contains(loc: WeaveLocation, node: AstNode): Boolean = {
      val nodeLoc = node.location()
      loc.startPosition.index >= nodeLoc.startPosition.index && loc.endPosition.index <= nodeLoc.endPosition.index
    }

    @tailrec
    def refineParent(currentParent: AstNode): AstNode =
      currentParent.children().find(child => contains(location, child)) match {
        case Some(child) => refineParent(child)
        case None        => currentParent
      }

    refineParent(rootNode)
  }
}
object Grammar {

  /** Default value for the `defaultErrorTrace` constructor parameter of [[Grammar]]. */
  val DEFAULT_ERROR_TRACE: Int = 5
}

/**
  * Indexes the nodes of an AST by the source line they start and end on. Comment placement uses the
  * index to find candidate nodes.
  *
  * Instances of [[ContainerAstNode]] are not indexed themselves, but their children are still visited.
  *
  * @param documentNode root of the tree to index; the tree is walked once, during construction
  */
class LineToNodeCache(documentNode: AstNode) {

  /** Nodes grouped by start line, in pre-order: an ancestor precedes its descendants. */
  val start = new mutable.TreeMap[Int, ArrayBuffer[AstNode]]()

  /** Nodes grouped by end line, in post-order: descendants precede their ancestor. */
  val end = new mutable.TreeMap[Int, ArrayBuffer[AstNode]]()

  process(documentNode)

  /**
    * Returns the first node, in pre-order, that starts on the nearest indexed line after `endLine`.
    */
  def lookFirstOnNextLine(endLine: Int): Option[AstNode] = {
    // `iteratorFrom` seeks to the first key >= endLine, so at most the entry for endLine itself is skipped.
    start
      .iteratorFrom(endLine)
      .find { case (line, _) => line > endLine }
      .flatMap { case (_, nodes) => nodes.headOption }
  }

  /**
    * Returns the last-visited node that ends on `startLine` before index `startPosition`.
    * Because the buffer is in post-order, this is the outermost such node.
    */
  def lookPreviousOnSameLine(startLine: Int, startPosition: Int): Option[AstNode] = {
    end.get(startLine).flatMap(nodes => nodes.reverseIterator.find(node => node.location().endPosition.index < startPosition))
  }

  /**
    * Returns the first node, in pre-order, that starts on `endLine` after index `endPosition`.
    */
  def lookNextOnSameLine(endLine: Int, endPosition: Int): Option[AstNode] = {
    start.get(endLine).flatMap(nodes => nodes.find(node => node.location().startPosition.index > endPosition))
  }

  def process(node: AstNode): Unit = {
    doProcess(node)
  }

  /** Recursively records `node` (unless it is a container) and all of its descendants. */
  def doProcess(node: AstNode): Unit = {
    val location = node.location()
    val indexed = !node.isInstanceOf[ContainerAstNode]
    if (indexed) {
      start.getOrElseUpdate(location.startPosition.line, ArrayBuffer.empty[AstNode]) += node
    }
    node.children().foreach(doProcess)
    // Recorded after the children so that `end` buffers hold nodes in post-order.
    if (indexed) {
      end.getOrElseUpdate(location.endPosition.line, ArrayBuffer.empty[AstNode]) += node
    }
  }
}
