package org.mule.weave.v2.debugger.client

import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

import org.mule.weave.v2.debugger.WeaveBreakpoint
import org.mule.weave.v2.debugger.WeaveExceptionBreakpoint
import org.mule.weave.v2.debugger.client.tcp.TcpClientProtocol
import org.mule.weave.v2.debugger.commands.AddBreakpointCommand
import org.mule.weave.v2.debugger.commands.AddBreakpointsCommand
import org.mule.weave.v2.debugger.commands.AddExceptionBreakpointsCommand
import org.mule.weave.v2.debugger.commands.ClearBreakpointsCommand
import org.mule.weave.v2.debugger.commands.DebuggerCommand
import org.mule.weave.v2.debugger.commands.EvaluateScriptCommand
import org.mule.weave.v2.debugger.commands.InitializeSessionCommand
import org.mule.weave.v2.debugger.commands.NextStepDebuggerCommand
import org.mule.weave.v2.debugger.commands.RemoveBreakpointCommand
import org.mule.weave.v2.debugger.commands.ResumeDebuggerCommand
import org.mule.weave.v2.debugger.commands.StepIntoDebuggerCommand
import org.mule.weave.v2.debugger.commands.StepOutDebuggerCommand
import org.mule.weave.v2.debugger.event.BreakpointAddedEvent
import org.mule.weave.v2.debugger.event.BreakpointRemovedEvent
import org.mule.weave.v2.debugger.event.BreakpointsAddedEvent
import org.mule.weave.v2.debugger.event.BreakpointsCleanedEvent
import org.mule.weave.v2.debugger.event.ClientInitializedEvent
import org.mule.weave.v2.debugger.event.DebuggerEvent
import org.mule.weave.v2.debugger.event.ExceptionBreakpointsAddedEvent
import org.mule.weave.v2.debugger.event.NextStepDebuggerEvent
import org.mule.weave.v2.debugger.event.OnFrameEvent
import org.mule.weave.v2.debugger.event.RemoteServerMessage
import org.mule.weave.v2.debugger.event.ResumeDebuggerEvent
import org.mule.weave.v2.debugger.event.ScriptResultEvent
import org.mule.weave.v2.debugger.event.StepIntoDebuggerEvent
import org.mule.weave.v2.debugger.event.StepOutDebuggerEvent
import org.mule.weave.v2.debugger.event.UnexpectedServerErrorEvent

import scala.collection.mutable

/**
  * Client side of the Weave debugger.
  *
  * Commands are sent through `protocol`; every message received from the server is routed to the
  * registered [[DebuggerEventHandler]]s. A [[DefaultDebuggerEventHandler]] forwarding to
  * `defaultEventHandler` is always registered, and each [[FutureResponse]] registers a one-shot
  * handler so callers can wait for the reply to a specific command.
  *
  * Breakpoints and exception breakpoints added while disconnected are buffered locally and sent with
  * the session initialization command in [[connect]].
  *
  * @param defaultEventHandler listener forwarded every server message by the default handler
  * @param protocol            transport used to talk to the debugger server
  */
class DebuggerClient(defaultEventHandler: DebuggerClientListener, protocol: ClientProtocol = new TcpClientProtocol()) {

  // Breakpoints registered while disconnected; shipped with the InitializeSessionCommand in connect().
  private val breakpoints: mutable.ListBuffer[WeaveBreakpoint] = mutable.ListBuffer()
  private val exceptionBreakpoints: mutable.ListBuffer[WeaveExceptionBreakpoint] = mutable.ListBuffer()

  // Handlers are added by callers (FutureResponse, evaluateScript) and removed from inside handle(),
  // i.e. while server messages are being dispatched (presumably on another thread, since waitResponse
  // blocks on a latch for it). ListBuffer is not thread-safe, so every access after construction goes
  // through registerHandler / unregisterHandler / dispatchEvent, which lock on the buffer.
  private val eventHandlers: mutable.ListBuffer[DebuggerEventHandler] = mutable.ListBuffer()

  //Register default event handler
  eventHandlers += new DefaultDebuggerEventHandler(defaultEventHandler, this)

  /**
    * Connects to the debugger server (if not already connected) and starts a debugging session.
    *
    * Registers the server message handler, then sends an [[InitializeSessionCommand]] carrying every
    * breakpoint and exception breakpoint added while disconnected; the local buffers are cleared
    * afterwards.
    *
    * @return the pending reply to the session initialization command
    */
  def connect(): FutureResponse[ClientInitializedEvent] = {
    protocol.addServerMessageHandler(
      classOf[DebuggerEvent],
      new ServerMessageHandler[DebuggerEvent] {
        override def handle(message: DebuggerEvent): Unit = {
          dispatchEvent(message)
        }

        override def handleUnexpectedError(unexpectedServerErrorEvent: UnexpectedServerErrorEvent): Unit = {
          dispatchEvent(unexpectedServerErrorEvent)
        }
      })

    if (!protocol.isConnected()) {
      protocol.connect()
    }
    //Add breakpoints that were added before connecting
    val sessionCommand = new InitializeSessionCommand(breakpoints.toArray, exceptionBreakpoints.toArray)
    // Register the response handler before sending so the reply cannot be missed.
    val response = FutureResponse[ClientInitializedEvent](sessionCommand)
    protocol.sendCommand(sessionCommand)
    breakpoints.clear()
    exceptionBreakpoints.clear()
    response
  }

  /** Returns whether the underlying protocol reports an open connection. */
  def isConnected(): Boolean = {
    protocol.isConnected()
  }

  /** Sends a [[ResumeDebuggerCommand]] and returns its pending reply. */
  def resume(): FutureResponse[ResumeDebuggerEvent] = {
    val command = new ResumeDebuggerCommand()
    val result = FutureResponse[ResumeDebuggerEvent](command)
    protocol.sendCommand(command)
    result
  }

  /** Sends a [[NextStepDebuggerCommand]] and returns its pending reply. */
  def nextStep(): FutureResponse[NextStepDebuggerEvent] = {
    val command = new NextStepDebuggerCommand()
    val result = FutureResponse[NextStepDebuggerEvent](command)
    protocol.sendCommand(command)
    result
  }

  /** Sends a [[StepIntoDebuggerCommand]] and returns its pending reply. */
  def stepInto(): FutureResponse[StepIntoDebuggerEvent] = {
    val command = new StepIntoDebuggerCommand()
    val result = FutureResponse[StepIntoDebuggerEvent](command)
    protocol.sendCommand(command)
    result
  }

  /** Sends a [[StepOutDebuggerCommand]] and returns its pending reply. */
  def stepOut(): FutureResponse[StepOutDebuggerEvent] = {
    val command = new StepOutDebuggerCommand()
    val result = FutureResponse[StepOutDebuggerEvent](command)
    protocol.sendCommand(command)
    result
  }

  /**
    * Adds a single breakpoint: sent immediately when connected, otherwise buffered until [[connect]].
    *
    * @return the value returned by `protocol.sendCommand` when connected; otherwise the internal
    *         pending-breakpoint buffer (kept for compatibility — callers should not rely on it)
    */
  def addBreakpoint(breakpoint: WeaveBreakpoint): Any = {
    if (protocol.isConnected()) {
      protocol.sendCommand(new AddBreakpointCommand(breakpoint))
    } else {
      breakpoints.+=(breakpoint)
    }
  }

  /**
    * Adds several breakpoints: sent immediately when connected, otherwise buffered until [[connect]].
    *
    * @return the pending reply when connected; `None` when the breakpoints were only buffered
    */
  def addBreakpoints(breakpoint: Array[WeaveBreakpoint]): Option[FutureResponse[BreakpointsAddedEvent]] = {
    if (protocol.isConnected()) {
      val breakpointsCommand = new AddBreakpointsCommand(breakpoint)
      val result = FutureResponse[BreakpointsAddedEvent](breakpointsCommand)
      protocol.sendCommand(breakpointsCommand)
      Some(result)
    } else {
      breakpoints.++=(breakpoint)
      None
    }
  }

  /**
    * Adds exception breakpoints: sent immediately when connected, otherwise buffered until [[connect]].
    *
    * @return the pending reply when connected; `None` when the breakpoints were only buffered
    */
  def addExceptionBreakpoints(breakpoints: Array[WeaveExceptionBreakpoint]): Option[FutureResponse[ExceptionBreakpointsAddedEvent]] = {
    if (protocol.isConnected()) {
      val breakpointsCommand = new AddExceptionBreakpointsCommand(breakpoints)
      val result = FutureResponse[ExceptionBreakpointsAddedEvent](breakpointsCommand)
      protocol.sendCommand(breakpointsCommand)
      Some(result)
    } else {
      exceptionBreakpoints.++=(breakpoints)
      None
    }
  }

  /**
    * Removes a breakpoint on the server when connected, otherwise from the local buffer.
    *
    * @return `Some` with the value returned by `protocol.sendCommand` when connected, else `None`
    */
  def removeBreakpoint(breakpoint: WeaveBreakpoint): Option[String] = {
    if (protocol.isConnected()) {
      Some(protocol.sendCommand(new RemoveBreakpointCommand(breakpoint)))
    } else {
      breakpoints.-=(breakpoint)
      None
    }
  }

  /**
    * Asks the server to evaluate `script` and reports the first [[ScriptResultEvent]] whose script
    * text equals `script` to `callback`. Does nothing when disconnected.
    *
    * NOTE(review): results are matched by script text, not by `id`, so two concurrent evaluations of
    * the same script both receive the first result.
    */
  def evaluateScript(id: Int, script: String, callback: ScriptEvaluationListener): Unit = {
    if (protocol.isConnected()) {
      //Listen for the evaluation results
      registerHandler(new DebuggerEventHandler {
        override def accepts(event: RemoteServerMessage): Boolean = {
          event match {
            case s: ScriptResultEvent => s.script.equals(script)
            case _                    => false
          }
        }

        override def handle(event: RemoteServerMessage): Unit = {
          // One-shot: unregister first so a throwing callback cannot leave this handler behind.
          unregisterHandler(this)
          event match {
            case s: ScriptResultEvent => callback.onScriptEvaluated(DebuggerClient.this, s)
            case _                    =>
          }
        }
      })
      protocol.sendCommand(new EvaluateScriptCommand(script, id))
    }
  }

  /**
    * Clears breakpoints: on the server when connected, otherwise the locally buffered ones so that
    * [[connect]] does not send them (mirroring [[removeBreakpoint]]).
    */
  def clearBreakpoints(): Unit = {
    if (protocol.isConnected()) {
      protocol.sendCommand(new ClearBreakpointsCommand())
    } else {
      // Exception breakpoints are left untouched: nothing visible here shows that
      // ClearBreakpointsCommand affects them on the server either.
      breakpoints.clear()
    }
  }

  /** Closes the underlying protocol connection. */
  def disconnect(): Unit = {
    protocol.disconnect()
  }

  /** Adds `handler` in front of the existing ones so it is consulted before the default handler. */
  private def registerHandler(handler: DebuggerEventHandler): Unit = {
    eventHandlers.synchronized {
      eventHandlers.prepend(handler)
    }
  }

  /** Removes a previously registered handler. */
  private def unregisterHandler(handler: DebuggerEventHandler): Unit = {
    eventHandlers.synchronized {
      eventHandlers -= handler
    }
  }

  /** Invokes `handle` on every registered handler that accepts `forEvent`, in registration order. */
  private def dispatchEvent(forEvent: RemoteServerMessage): Unit = {
    // Iterate over a snapshot: handlers unregister themselves while handling, and other threads may
    // register new ones concurrently.
    val snapshot = eventHandlers.synchronized(eventHandlers.toList)
    snapshot.filter(_.accepts(forEvent)).foreach(_.handle(forEvent))
  }

  /**
    * Pending reply to a single command, matched by the command id carried in server messages.
    *
    * Constructing an instance registers a one-shot handler. The first server message whose
    * `commandId` equals `commandId` completes this response and unregisters the handler.
    *
    * @param commandId id of the command whose reply is awaited
    * @tparam T expected reply type; the cast is unchecked (erased)
    */
  class FutureResponse[T <: DebuggerEvent](commandId: String) {

    // Initialized before the handler is registered below; otherwise a reply dispatched in between
    // would call countDown() on a not-yet-initialized (null) latch.
    val countDownLatch = new CountDownLatch(1)

    @volatile var result: Option[T] = None
    var listener: (T) => Unit = _

    registerHandler(new DebuggerEventHandler {
      override def accepts(event: RemoteServerMessage): Boolean = {
        event.commandId.contains(commandId)
      }

      override def handle(event: RemoteServerMessage): Unit = {
        unregisterHandler(this)
        // NOTE(review): unchecked cast; a reply of an unexpected type only fails at the caller.
        val response = event.asInstanceOf[T]
        // Publish the result and read the listener atomically: a concurrent onResponse either sees
        // the result or its listener is seen here, so the listener runs exactly once.
        val callback = FutureResponse.this.synchronized {
          result = Some(response)
          listener
        }
        // Release waiters before running the listener so a failing listener cannot block them.
        countDownLatch.countDown()
        if (callback != null) {
          callback(response)
        }
      }
    })

    /** Sets the reply listener; it is invoked immediately if the reply has already arrived. */
    def onResponse(listener: (T) => Unit): Unit = {
      val alreadyReceived = this.synchronized {
        this.listener = listener
        result
      }
      alreadyReceived.foreach(listener)
    }

    /** Blocks up to 10 minutes for the reply; see the two-argument overload. */
    def waitResponse(): T = waitResponse(10, TimeUnit.MINUTES)

    /**
      * Blocks until the reply arrives or `timeout` elapses.
      *
      * @throws NoSuchElementException if no reply arrived within the timeout
      */
    def waitResponse(timeout: Long, unit: TimeUnit): T = {
      if (result.isEmpty) {
        countDownLatch.await(timeout, unit)
      }
      result.getOrElse(
        throw new NoSuchElementException(s"No response received for command $commandId within $timeout $unit"))
    }
  }

  object FutureResponse {
    /** Creates (and registers) the pending reply for `command`; call before sending the command. */
    def apply[T <: DebuggerEvent](command: DebuggerCommand): FutureResponse[T] = new FutureResponse(command.id)
  }

}

/**
  * Receives messages sent by the debugger server.
  *
  * [[DebuggerClient]] consults its registered handlers for every incoming message and calls `handle`
  * only on those whose `accepts` returned `true` for it. Besides replies to commands, this includes
  * messages such as frame events and unexpected server errors.
  */
trait DebuggerEventHandler {

  /** Returns `true` if this handler wants to process `event`. */
  def accepts(event: RemoteServerMessage): Boolean

  /** Processes an `event` that this handler accepted. */
  def handle(event: RemoteServerMessage): Unit
}

/**
  * Catch-all handler: accepts every server message and forwards it to the matching callback of a
  * [[DebuggerClientListener]].
  *
  * Any exception raised while forwarding is printed and swallowed. This includes a `MatchError` for a
  * message type with no case in `route`.
  *
  * @param defaultListener listener receiving the callbacks
  * @param debuggerClient  client passed to the callbacks that take one (frames, script results)
  */
class DefaultDebuggerEventHandler(defaultListener: DebuggerClientListener, debuggerClient: DebuggerClient) extends DebuggerEventHandler {

  override def accepts(event: RemoteServerMessage): Boolean = true

  override def handle(event: RemoteServerMessage): Unit =
    try {
      route(event)
    } catch {
      case e: Exception => e.printStackTrace()
    }

  /** Invokes the listener callback that corresponds to the concrete type of `message`. */
  private def route(message: RemoteServerMessage): Unit = message match {
    case added: BreakpointAddedEvent                   => defaultListener.onBreakpointAdded(added)
    case addedMany: BreakpointsAddedEvent              => defaultListener.onBreakpointsAdded(addedMany)
    case removed: BreakpointRemovedEvent               => defaultListener.onBreakpointRemoved(removed)
    case cleaned: BreakpointsCleanedEvent              => defaultListener.onBreakpointCleaned(cleaned)
    case frame: OnFrameEvent                           => defaultListener.onFrame(debuggerClient, frame)
    case scriptResult: ScriptResultEvent               => defaultListener.onScriptEvaluated(debuggerClient, scriptResult)
    case initialized: ClientInitializedEvent           => defaultListener.onClientInitialized(initialized)
    case serverError: UnexpectedServerErrorEvent       => defaultListener.onUnexpectedError(serverError)
    case stepped: NextStepDebuggerEvent                => defaultListener.onNextStepExecuted(stepped)
    case steppedIn: StepIntoDebuggerEvent              => defaultListener.onStepIntoExecuted(steppedIn)
    case steppedOut: StepOutDebuggerEvent              => defaultListener.onStepOutExecuted(steppedOut)
    case resumed: ResumeDebuggerEvent                  => defaultListener.onResumeExecuted(resumed)
    case exceptionBps: ExceptionBreakpointsAddedEvent  => defaultListener.onExceptionBreakpointsAdded(exceptionBps)
  }
}
