package org.mule.weave.v2.parser

import org.mule.weave.v2.grammar.Grammar
import org.mule.weave.v2.grammar.Tokens
import org.mule.weave.v2.inspector.Inspector
import org.mule.weave.v2.inspector.NoInspector
import org.mule.weave.v2.inspector.ScopeCodeInspectorPhase
import org.mule.weave.v2.parser.MappingParser.parsingValidations
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.phase.AstNodeResultAware
import org.mule.weave.v2.parser.phase.AstNodeVerificationPhase
import org.mule.weave.v2.parser.phase.CompilationPhase
import org.mule.weave.v2.parser.phase.FailureResult
import org.mule.weave.v2.parser.phase.ParsingAnnotationProcessorPhase
import org.mule.weave.v2.parser.phase.ParsingContentAware
import org.mule.weave.v2.parser.phase.ParsingContentInput
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.parser.phase.ParsingPhase
import org.mule.weave.v2.parser.phase.ParsingResult
import org.mule.weave.v2.parser.phase.PhaseResult
import org.mule.weave.v2.parser.phase.ReverseTypeCheckingPhase
import org.mule.weave.v2.parser.phase.ScopeGraphResult
import org.mule.weave.v2.parser.phase.TypeCheckingResult
import org.mule.weave.v2.sdk.WeaveResource
import org.parboiled2.ParseError
import org.parboiled2.Parser.DeliveryScheme.Either

import java.lang.Character.isWhitespace
import scala.collection.mutable.ArrayBuffer

/**
  * Parses a Weave document (either a mapping or a module) and drives the later compilation phases
  * over the result: canonicalisation, scope checking, type checking and reverse type checking.
  *
  * It also offers [[parseWithRecovery]], a best-effort parse that patches a document that fails to
  * parse (for example one that is being edited) with placeholder text so a tree can still be built.
  *
  * @param errorTrace forwarded to [[Grammar]] on every parse
  * @param inspector  supplies optional inspection phases appended to the scope and type-check pipelines
  */
class DocumentParser(errorTrace: Int = 2, inspector: Inspector = NoInspector) {

  /**
    * Parses `content`; if that fails and `maybeLocation` is defined, retries on patched copies of the text.
    *
    * Recovery runs in two stages, each starting from a fresh copy of the original content:
    *  1. insert the fake variable name at `maybeLocation`;
    *  2. if that still fails, apply the context-aware heuristics of `sanitizeDocument` at the same offset.
    * Each stage may retry once more on its already-patched text (see `tryRecoverFromError`).
    *
    * @param content       the document to parse
    * @param context       parsing context; every attempt uses it with a fresh `MessageCollector`
    * @param maybeLocation offset where recovery text is inserted (presumably the editor cursor);
    *                      when `None` no recovery is attempted
    * @return the first successful result, or the failed result of the last attempt
    */
  def parseWithRecovery(content: WeaveResource, context: ParsingContext, maybeLocation: Option[Int] = None): PhaseResult[ParsingResult[AstNode]] = {
    val initialResult = parse(content, context.withMessageCollector(new MessageCollector))
    maybeLocation match {
      case Some(location) if !initialResult.hasResult() =>
        val withFakeVariable = tryRecoverFromError(
          content,
          context,
          location,
          0,
          (offset: Int, document: TextDocument, _: Int) => {
            // TODO should the retry count be added to the fake variable?
            document.insert(offset, NameIdentifier.INSERTED_FAKE_VARIABLE_NAME)
          })
        if (withFakeVariable.hasResult()) {
          withFakeVariable
        } else {
          tryRecoverFromError(content, context, location, 0, sanitizeDocument)
        }
      case _ =>
        initialResult
    }
  }

  /**
    * Runs `sanitizer` on a copy of `content` at `cursorLocation` and parses the patched text.
    *
    * If that parse fails on the first attempt (`count == 0`), it retries once on the already-patched
    * text, at the end position of the first reported error.
    *
    * @param count     attempt number; no further retries happen once it reaches 1
    * @param sanitizer patches a document in place, given (location, document, attempt number)
    * @return the successful result, or the failed result of the last attempt
    */
  private def tryRecoverFromError(content: WeaveResource, context: ParsingContext, cursorLocation: Int, count: Int = 0, sanitizer: (Int, TextDocument, Int) => Unit): PhaseResult[ParsingResult[AstNode]] = {
    val parsingContext = context.withMessageCollector(new MessageCollector)
    val document = new TextDocument(content.content())
    sanitizer(cursorLocation, document, count)
    val modifiedContent = WeaveResource(content.url(), document.text())
    val result = parse(modifiedContent, parsingContext)
    // `>=` rather than `==` so an unexpected count can never recurse without bound.
    if (result.hasResult() || count >= 1) {
      result
    } else {
      // A failure that reports no error has no position to retry at: give up instead of calling `head`.
      result.errorMessages().headOption.map(_._1.endPosition.index) match {
        case Some(errorIndex) => tryRecoverFromError(modifiedContent, context, errorIndex, count + 1, sanitizer)
        case None             => result
      }
    }
  }

  /**
    * Splits `text` into tokens at the characters in `elements`.
    *
    * A delimiter only ends a word when the pending word is not blank; the delimiter is then emitted as
    * a token of its own unless it is whitespace. A delimiter met while the pending word is blank is
    * treated as part of the next word. Word text is trimmed exactly like `String.trim` (characters
    * `<= ' '` are dropped from both ends). `start`/`end` are the inclusive offsets in `text` of the
    * token's (trimmed) text.
    */
  private def tokenize(text: String, elements: Seq[Char]): Seq[Token] = {
    val result: ArrayBuffer[Token] = ArrayBuffer()
    // Offsets of the first/last character of the pending word that `trim` would keep; -1 while blank.
    var wordFirst = -1
    var wordLast = -1

    // Emits the pending (non-blank) word and starts a new one.
    def flushWord(): Unit = {
      result += Token(text.substring(wordFirst, wordLast + 1), wordFirst, wordLast)
      wordFirst = -1
      wordLast = -1
    }

    var index = 0
    while (index < text.length) {
      val c = text.charAt(index)
      if (wordFirst >= 0 && elements.contains(c)) {
        flushWord()
        if (!c.isWhitespace) {
          result += Token(c.toString, index, index)
        }
      } else if (c > ' ') {
        if (wordFirst < 0) {
          wordFirst = index
        }
        wordLast = index
      }
      index += 1
    }

    if (wordFirst >= 0) {
      flushWord()
    }
    result
  }

  /**
    * Context-aware recovery sanitizer. Looks backwards through the text before `location` for the
    * nearest keyword in `DocumentParser.TOKEN_DIRECTIVE` and inserts placeholder text suited to it.
    *
    * @param location   offset at which text is inserted
    * @param document   document patched in place
    * @param retryCount 0 on the first attempt, 1 on the retry
    */
  private def sanitizeDocument(location: Int, document: TextDocument, retryCount: Int): Unit = {
    // Reversed so that index 0 is the token closest to `location` (look from bottom to top).
    val tokens = tokenize(document.textUntil(location), DocumentParser.TOKENS).reverse
    val directive = tokens.find(token => DocumentParser.TOKEN_DIRECTIVE.contains(token.text)).map(_.text)
    directive match {
      case Some(Tokens.DOCUMENT_SEPARATOR | Tokens.VAR) =>
        fixBodyExpression(location, document, retryCount)
      case Some(Tokens.INPUT | Tokens.OUTPUT) =>
        // We need to determine if we are in a property: decide from the token right before the cursor.
        tokens.head.text match {
          case ","                          => document.insert(location, DocumentParser.FAKE_VARIABLE_NAME + " = true")
          case "="                          => document.insert(location, "true")
          case Tokens.INPUT | Tokens.OUTPUT => document.insert(location, DocumentParser.FAKE_VARIABLE_NAME)
          case _                            => document.insert(location, " = true")
        }
      case Some(Tokens.FUNCTION) =>
        val functionTokenIndex = tokens.indexWhere(_.text == Tokens.FUNCTION)
        val bodySeparator = tokens.indexWhere(_.text == "=")
        // `tokens` is reversed: an `=` with a smaller index than the nearest `fun` comes AFTER it in the
        // document, which means we are in the function body.
        if (bodySeparator >= 0 && bodySeparator < functionTokenIndex) {
          fixBodyExpression(location, document, retryCount)
        } else {
          // most probably we are in a parameter
          normalSanitize(location, document)
        }
      case _ =>
        normalSanitize(location, document)
    }
  }

  /**
    * Patches an expression inside a body.
    *
    * On the first attempt (`retryCount == 0`) this is just `normalSanitize`. On the retry it walks back
    * from `location` over whitespace and then along the current line looking for the nearest `:`, `(`,
    * `=`, `[` or quote character, and tries to balance it: a `,` is inserted after `:`, the word after `=`
    * is wrapped in double quotes, and the other openers are closed right after the word at `location`.
    */
  private def fixBodyExpression(location: Int, document: TextDocument, retryCount: Int): Unit = {
    if (retryCount == 0) {
      normalSanitize(location, document)
    } else {
      //Try to detect unbalance [ or "
      var i = location
      while (i > 0 && isWhitespace(document.charAt(i))) {
        i = i - 1
      }
      var c = document.charAt(i)
      // Setting `i = -1` ends the scan once a fix has been applied.
      while ((c != '\n' || i == location) && i > 0) {
        c match {
          case ':' =>
            document.insert(location, ',')
            i = -1
          case '=' =>
            // Find where the closing quote goes BEFORE inserting the opening one: that insertion shifts
            // everything after `i`, so `location` would otherwise point one character too early.
            val closeAt = document.endOfWord(location)
            document.insert(closeAt, '"')
            document.insert(i + 1, '"')
            i = -1
          case '(' | '[' | '"' | '\'' | '`' =>
            val closer = c match {
              case '(' => ')'
              case '[' => ']'
              case quote => quote
            }
            document.insert(document.endOfWord(location), closer)
            i = -1
          case _ =>
            i = i - 1
            c = document.charAt(i)
        }
      }
    }
  }

  /**
    * Default recovery: inserts the fake variable name at `location`, padded with spaces when the previous
    * character is a letter so it does not merge with that word. Does nothing when `location` is `<= 0`
    * or past the end of the document.
    */
  private def normalSanitize(location: Int, document: TextDocument): Unit = {
    if (location > 0 && location <= document.length) {
      if (document.charAt(location - 1).isLetter) {
        document.insert(location, " " + DocumentParser.FAKE_VARIABLE_NAME + " ")
      } else {
        document.insert(location, DocumentParser.FAKE_VARIABLE_NAME)
      }
    }
  }

  /** The parsing pipeline: grammar parse, AST verification (`parsingValidations`), then annotation processing. */
  def parsingPhase(): CompilationPhase[ParsingContentInput, ParsingResult[AstNode]] = {
    ParsingPhase[AstNode](parse)
      .chainWith(new AstNodeVerificationPhase[AstNode, ParsingResult[AstNode]](parsingValidations))
      .chainWith(new ParsingAnnotationProcessorPhase[AstNode, ParsingResult[AstNode]]())
  }

  /** Runs [[parsingPhase]] over `input`. */
  def parse(input: WeaveResource, parsingContext: ParsingContext): PhaseResult[ParsingResult[AstNode]] = {
    val parserInput = ParsingContentInput(input, parsingContext.nameIdentifier, SafeStringBasedParserInput(input.content()))
    parsingPhase().call(parserInput, parsingContext)
  }

  /**
    * Runs the canonical phases matching the parsed node (mapping vs module).
    * Returns a `FailureResult` when `previous` has no result; a node that is neither a `DocumentNode`
    * nor a `ModuleNode` raises a `MatchError`.
    */
  def canonical(previous: PhaseResult[ParsingResult[AstNode]], parsingContext: ParsingContext): PhaseResult[ParsingResult[_ <: AstNode]] = {
    if (previous.hasResult()) {
      val result = previous.getResult()
      result.astNode match {
        case _: DocumentNode =>
          MappingParser.canonicalPhasePhases().call(result.asInstanceOf[ParsingResult[DocumentNode]], parsingContext)
        case _: ModuleNode =>
          ModuleParser.canonicalPhasePhases().call(result.asInstanceOf[ParsingResult[ModuleNode]], parsingContext)
      }
    } else {
      FailureResult(parsingContext)
    }
  }

  /**
    * Runs the scope phases matching the parsed node, followed by the inspector's scope inspection
    * phase when it provides one. Returns a `FailureResult` when `previous` has no result.
    */
  def scopeCheck(previous: PhaseResult[ParsingResult[_]], parsingContext: ParsingContext): PhaseResult[ScopeGraphResult[_ <: AstNode]] = {
    if (previous.hasResult()) {
      val result = previous.getResult()
      result.astNode match {
        case node: DocumentNode =>
          val baseCompiler: CompilationPhase[AstNodeResultAware[DocumentNode] with ParsingContentAware, ScopeGraphResult[DocumentNode]] = MappingParser.scopePhasePhases()
          val mayBeInspector: Option[ScopeCodeInspectorPhase[DocumentNode, ScopeGraphResult[DocumentNode]]] = inspector.scopeInspectionsFor(node)
          val compiler = mayBeInspector match {
            case Some(inspectionPhase) => baseCompiler.chainWith(inspectionPhase)
            case None                  => baseCompiler
          }
          compiler.call(result.asInstanceOf[ParsingResult[DocumentNode]], parsingContext)
        case node: ModuleNode =>
          val baseCompiler: CompilationPhase[AstNodeResultAware[ModuleNode] with ParsingContentAware, ScopeGraphResult[ModuleNode]] = ModuleParser.scopePhasePhases()
          val mayBeInspector: Option[ScopeCodeInspectorPhase[ModuleNode, ScopeGraphResult[ModuleNode]]] = inspector.scopeInspectionsFor(node)
          val compiler = mayBeInspector match {
            case Some(inspectionPhase) => baseCompiler.chainWith(inspectionPhase)
            case None                  => baseCompiler
          }
          compiler.call(result.asInstanceOf[ParsingResult[ModuleNode]], parsingContext)
      }
    } else {
      FailureResult(parsingContext)
    }
  }

  /**
    * Runs the type-check phases matching the node, followed by the inspector's type inspection phase
    * when it provides one. Returns a `FailureResult` when `previous` has no result.
    */
  def typeCheck(previous: PhaseResult[ScopeGraphResult[_]], parsingContext: ParsingContext): PhaseResult[TypeCheckingResult[_ <: AstNode]] = {
    if (previous.hasResult()) {
      val result = previous.getResult()
      result.astNode match {
        case node: DocumentNode =>
          val baseCompiler = MappingParser.typeCheckPhasePhases()
          val mayBeInspector = inspector.typeInspectionsFor(node)
          val compiler = mayBeInspector match {
            case Some(inspectionPhase) => baseCompiler.chainWith(inspectionPhase)
            case None                  => baseCompiler
          }
          compiler.call(result.asInstanceOf[ScopeGraphResult[DocumentNode]], parsingContext)
        case node: ModuleNode =>
          val baseCompiler = ModuleParser.typeCheckPhasePhases()
          val mayBeInspector = inspector.typeInspectionsFor(node)
          val compiler = mayBeInspector match {
            case Some(inspectionPhase) => baseCompiler.chainWith(inspectionPhase)
            case None                  => baseCompiler
          }
          compiler.call(result.asInstanceOf[ScopeGraphResult[ModuleNode]], parsingContext)
      }
    } else {
      FailureResult(parsingContext)
    }
  }

  /** Parse, canonical and scope phases in sequence. */
  def runScopePhases(weaveResource: WeaveResource, parsingContext: ParsingContext): PhaseResult[ScopeGraphResult[_ <: AstNode]] = {
    scopeCheck(canonical(parse(weaveResource, parsingContext), parsingContext), parsingContext)
  }

  /** [[runScopePhases]] followed by [[typeCheck]]. */
  def runAllPhases(weaveResource: WeaveResource, parsingContext: ParsingContext): PhaseResult[TypeCheckingResult[_ <: AstNode]] = {
    typeCheck(runScopePhases(weaveResource, parsingContext), parsingContext)
  }

  /** Runs a `ReverseTypeCheckingPhase` for the node's kind; `FailureResult` when `previous` has no result. */
  def reverseTypeCheck(previous: PhaseResult[TypeCheckingResult[_]], parsingContext: ParsingContext): PhaseResult[TypeCheckingResult[_ <: AstNode]] = {
    if (previous.hasResult()) {
      val result = previous.getResult()
      result.astNode match {
        case _: DocumentNode =>
          val phase = new ReverseTypeCheckingPhase[DocumentNode]()
          phase.call(result.asInstanceOf[TypeCheckingResult[DocumentNode]], parsingContext)
        case _: ModuleNode =>
          val phase = new ReverseTypeCheckingPhase[ModuleNode]()
          phase.call(result.asInstanceOf[TypeCheckingResult[ModuleNode]], parsingContext)
      }
    } else {
      FailureResult(parsingContext)
    }
  }

  // Grammar entry point used by `ParsingPhase`; `context` is part of the expected function shape but unused.
  private def parse(parserInput: ParsingContentInput, context: ParsingContext): Either[ParseError, AstNode] = {
    val parser = new Grammar(parserInput.input, parserInput.nameIdentifier, errorTrace)
    parser.weavedocument.run()
  }

}

/**
  * Mutable text buffer used by the parser's error recovery to patch a document before reparsing it.
  *
  * Offsets are 0-based character indices into the current content. The read helpers (`charAt`,
  * `isWhitespace`, `isToken`, `text(start, end)`, `textUntil`) tolerate out-of-range offsets.
  *
  * @param initialContent the starting text
  */
class TextDocument(initialContent: String) {

  /** The current text; modified in place by `insert` and `delete`. */
  val content = new StringBuilder(initialContent)

  /** Removes the characters in `[startLocation, endLocation)`. */
  def delete(startLocation: Int, endLocation: Int): Unit = {
    content.delete(startLocation, endLocation)
  }

  /**
    * Returns the first offset at or after `location` that holds whitespace or one of
    * `DocumentParser.SPECIAL_TOKENS`, or `length` when there is none (i.e. just past the word).
    */
  def endOfWord(location: Int): Int = {
    var i = location
    while (i < length && !(isWhitespace(i) || isToken(i))) {
      i = i + 1
    }
    i
  }

  /** Length of the content after `String.trim`. */
  def trimLength(): Int = content.toString().trim.length

  /**
    * Returns the first `location` characters, clamped to the content: empty when `location <= 0`,
    * the whole text when `location >= length`.
    */
  def textUntil(location: Int): String = text(0, location)

  /**
    * Returns the offset where the word ending just before `index` starts, scanning backwards until
    * whitespace, a special token or the start of the text.
    */
  def wordStart(index: Int): Int = {
    var i = index
    while (i > 0 && !isWhitespace(i - 1) && !isToken(i - 1)) {
      i = i - 1
    }
    i
  }

  /** `true` when `index` lies inside the content (including 0) and holds a whitespace character. */
  def isWhitespace(index: Int): Boolean = {
    inBounds(index) && content.charAt(index).isWhitespace
  }

  /** `true` when `index` lies inside the content (including 0) and holds one of `DocumentParser.SPECIAL_TOKENS`. */
  def isToken(index: Int): Boolean = {
    inBounds(index) && DocumentParser.SPECIAL_TOKENS.contains(content.charAt(index))
  }

  /** Current number of characters. */
  def length: Int = content.length

  /**
    * Inserts `content` at `location`. A location past the end pads the text with spaces up to it;
    * a negative location is ignored.
    */
  def insert(location: Int, content: String): Unit = {
    if (location > this.content.length) {
      val missing = location - this.content.length
      val extraSpaces = " " * missing
      this.content.append(extraSpaces + content)
    } else if (location >= 0) {
      this.content.insert(location, content)
    }
  }

  /** Inserts a single character, with the same bounds handling as the `String` overload. */
  def insert(location: Int, content: Char): Unit = {
    insert(location, content.toString)
  }

  /** Returns the character at `i`, or `' '` when `i` is outside `[0, length)`. */
  def charAt(i: Int): Char = {
    if (inBounds(i)) content.charAt(i) else ' '
  }

  /**
    * Returns the part of the line containing `offset` that runs from the start of that line up to and
    * including the character at `offset`. `offset` is clamped to the content; the result is empty when
    * the content is empty or the character at `offset` is a line break.
    *
    * @param offset position inside the content
    * @return the line prefix ending at `offset`
    */
  def getLineContentOf(offset: Int): String = {
    if (content.isEmpty) {
      ""
    } else {
      val end = Math.min(Math.max(offset, 0), content.length - 1)
      if (content.charAt(end) == '\n') {
        ""
      } else {
        val lineStart = content.lastIndexOf("\n", end) + 1
        content.substring(lineStart, end + 1)
      }
    }
  }

  /** The whole current text. */
  def text(): String = {
    content.toString()
  }

  /** The text in `[start, end)`, with both bounds clamped to `[0, length]`. */
  def text(start: Int, end: Int): String = {
    val safeStart = Math.min(Math.max(start, 0), length)
    val safeEnd = Math.min(Math.max(end, 0), length)
    content.substring(safeStart, safeEnd)
  }

  // `true` when `index` is a valid character position of the current content.
  private def inBounds(index: Int): Boolean = index >= 0 && index < content.length
}

/**
  * A token produced by the parser's recovery tokenizer.
  *
  * @param text  the token's text
  * @param start start offset of the token in the tokenized text, as assigned by the tokenizer
  * @param end   end offset of the token in the tokenized text, as assigned by the tokenizer
  */
case class Token(
  text: String,
  start: Int,
  end: Int
)

/**
  * Companion of [[DocumentParser]]: a factory plus the keyword and character tables used by its
  * error-recovery heuristics.
  */
object DocumentParser {

  /** Placeholder identifier inserted into a broken document so that it can be reparsed. */
  val FAKE_VARIABLE_NAME = NameIdentifier.INSERTED_FAKE_VARIABLE_NAME

  /** Directive keywords searched for (nearest first) to decide how to patch a document, plus the document separator. */
  val TOKEN_DIRECTIVE = Seq(
    Tokens.OUTPUT,
    Tokens.INPUT,
    Tokens.VAR,
    Tokens.FUNCTION,
    Tokens.TYPE,
    Tokens.NS,
    Tokens.DOCUMENT_SEPARATOR
  )

  /** Delimiters used to tokenize text during recovery. */
  val TOKENS: Seq[Char] = Seq(' ', '\n', '=', ',')

  /** Characters that, like whitespace, terminate a word in `TextDocument`. */
  val SPECIAL_TOKENS: Seq[Char] = Seq(')', '(', ']', '[', '{', '}', '.', '/', '+', '-', '*', '@', '~', '^')

  /** Creates a [[DocumentParser]]; arguments and defaults mirror its constructor. */
  def apply(errorTrace: Int = 2, inspector: Inspector = NoInspector): DocumentParser =
    new DocumentParser(errorTrace = errorTrace, inspector = inspector)
}
