package org.mule.weave.v2.debugger.client.tcp

import java.io.IOException
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.net.InetSocketAddress
import java.net.Socket
import java.net.SocketException

import org.mule.weave.v2.debugger.client.ClientProtocol
import org.mule.weave.v2.debugger.client.ServerMessageHandler
import org.mule.weave.v2.debugger.commands.ClientCommand
import org.mule.weave.v2.debugger.event.RemoteServerMessage
import org.mule.weave.v2.debugger.event.UnexpectedServerErrorEvent

import scala.collection.mutable.ArrayBuffer
import scala.util.Try

/**
  * [[ClientProtocol]] implementation that talks to the debugger server over a plain TCP socket.
  *
  * Every message is written/read with Java serialization, using a fresh `ObjectOutputStream` /
  * `ObjectInputStream` per message. NOTE(review): the server presumably does the same; the wire
  * format depends on it, so do not switch to a long-lived stream without changing the server too.
  *
  * After [[connect]] a daemon thread ("Weave TCP Client Poll") reads messages from the socket it
  * was started for and dispatches them to the registered [[ServerMessageHandler]]s.
  *
  * @param host server host name
  * @param port server TCP port
  */
class TcpClientProtocol(host: String = "localhost", port: Int = TcpClientProtocol.DEFAULT_PORT) extends ClientProtocol {

  /**
    * Socket of the current connection, or `null` when disconnected.
    *
    * Volatile because the listener thread and callers read it without holding the lock; this
    * class only reassigns it while holding this instance's monitor.
    */
  @volatile var client: Socket = _

  /**
    * Registered `(message class, handler)` pairs, in registration order.
    * Accessed under its own monitor because the listener thread iterates it while callers append.
    */
  val handlers = ArrayBuffer[(Class[_], ServerMessageHandler[_])]()

  /**
    * Blocks until the next message arrives on the current connection.
    *
    * @return the message, or `None` when not connected or when the read failed
    *         (socket/IO failures also disconnect the client)
    */
  def waitForEvent(): Option[RemoteServerMessage] = {
    // Work on a single snapshot so a concurrent disconnect cannot null it under us.
    val socket = client
    if (isSocketConnected(socket)) readEvent(socket) else None
  }

  /**
    * Reads one serialized message from `socket`.
    * Socket/IO failures close `socket`; any other failure is printed and yields `None`
    * without closing the connection.
    */
  private def readEvent(socket: Socket): Option[RemoteServerMessage] = {
    try {
      val stream: ObjectInputStream = new ObjectInputStream(socket.getInputStream)
      Some(stream.readObject().asInstanceOf[RemoteServerMessage])
    } catch {
      case _: SocketException =>
        closeSocket(socket)
        None
      case _: IOException =>
        println("Disconnecting client as server is down")
        closeSocket(socket)
        None
      case e: Exception =>
        e.printStackTrace()
        None
    }
  }

  /**
    * Serializes `command` to the server if connected. An `IOException` while writing disconnects
    * the client. The command is silently dropped when not connected.
    *
    * @return the id of `command`, whether or not it was actually sent
    */
  override def sendCommand[ContextType, RemoteServerResponse <: RemoteServerMessage](command: ClientCommand[ContextType, RemoteServerResponse]): String = {
    if (isClientConnected) {
      synchronized {
        // Re-read under the lock: another thread may have disconnected in the meantime.
        val socket = client
        if (isSocketConnected(socket)) {
          try {
            val stream: ObjectOutputStream = new ObjectOutputStream(socket.getOutputStream)
            stream.reset()
            stream.writeObject(command)
            stream.flush()
          } catch {
            case _: IOException => closeSocket(socket)
          }
        }
      }
    }
    command.id
  }

  /** Whether the current socket exists, is connected, is not closed and is still open for reading. */
  def isClientConnected: Boolean = isSocketConnected(client)

  /** Null-safe connectivity check on one socket snapshot (avoids racing with `client = null`). */
  private def isSocketConnected(socket: Socket): Boolean =
    socket != null && !socket.isClosed && socket.isConnected && !socket.isInputShutdown

  /**
    * Opens a connection to `host:port` and starts the listener thread, unless already connected.
    * Rethrows whatever `Socket.connect` throws, after closing the half-open socket.
    */
  override def connect(): Unit = {
    if (!isClientConnected) {
      synchronized {
        if (!isClientConnected) {
          val socket = new Socket()
          println("Trying to connect to " + host + " " + port)
          try {
            socket.connect(new InetSocketAddress(host, port))
          } catch {
            case e: Exception =>
              // Don't leak the unconnected socket; the next connect() creates a fresh one.
              Try(socket.close())
              throw e
          }
          client = socket
          val debuggerClient: Thread = new Thread(new DebuggerEventListener(socket), "Weave TCP Client Poll")
          debuggerClient.setDaemon(true)
          debuggerClient.start()
        }
      }
    }
  }

  /** Closes the current connection, if any; its listener thread then exits its read loop. */
  override def disconnect(): Unit = closeSocket(client)

  /**
    * Closes `socket`, and clears [[client]] only if it still refers to that socket.
    * A stale caller, such as an old listener thread, can therefore never tear down a newer
    * connection.
    */
  private def closeSocket(socket: Socket): Unit = synchronized {
    if (socket != null) {
      Try(socket.close())
      if (client eq socket) {
        client = null
      }
    }
  }

  override def isConnected(): Boolean = isClientConnected

  /**
    * Registers `handler` for messages whose class is assignable to `clazz`.
    * Only the first matching handler, in registration order, receives a regular message.
    */
  override def addServerMessageHandler[T <: RemoteServerMessage](clazz: Class[_], handler: ServerMessageHandler[T]): Unit = {
    handlers.synchronized {
      handlers += ((clazz, handler))
    }
  }

  /**
    * Reads and dispatches messages from `socket` until it disconnects or is replaced by a newer
    * connection, then closes it.
    *
    * @param socket the connection this listener owns; defaults to the current [[client]]
    */
  class DebuggerEventListener(socket: Socket = client) extends Runnable {

    /**
      * Routes `forEvent` to handlers. An [[UnexpectedServerErrorEvent]] goes to every handler.
      * Any other message goes only to the first handler whose registered class is assignable
      * from it. Exceptions thrown by that handler are swallowed.
      */
    def dispatchEvent(forEvent: RemoteServerMessage): Unit = {
      // Snapshot so concurrent registration can't disturb iteration.
      val registered = handlers.synchronized(handlers.toList)
      forEvent match {
        case ue: UnexpectedServerErrorEvent =>
          registered.foreach(_._2.handleUnexpectedError(ue))
        case _ =>
          registered.find(_._1.isAssignableFrom(forEvent.getClass)).foreach { handler =>
            Try {
              val value = handler._2
              value.handle(forEvent.asInstanceOf[value.TypeOfMessage])
            }
          }
      }
    }

    override def run(): Unit = {
      try {
        // Stop as soon as this socket is no longer the active connection, so a reconnect never
        // leaves two listeners reading from the same socket.
        while ((socket eq client) && isSocketConnected(socket)) {
          readEvent(socket).foreach(event => Try(dispatchEvent(event)))
        }
      } finally {
        closeSocket(socket)
      }
    }
  }

}

/** Factory and shared defaults for [[TcpClientProtocol]]. */
object TcpClientProtocol {

  /** TCP port used when no port is given explicitly. */
  val DEFAULT_PORT: Int = 6565

  /**
    * Builds a new client protocol; it does not connect until `connect()` is called.
    *
    * @param host server host name (defaults to `"localhost"`)
    * @param port server TCP port (defaults to [[DEFAULT_PORT]])
    */
  def apply(host: String = "localhost", port: Int = DEFAULT_PORT): TcpClientProtocol =
    new TcpClientProtocol(host = host, port = port)
}

