package uk.neilgall.kanren

/**
 * Bindings accumulated during a search: each key is bound to the term it maps to.
 * In this file, bindings are only ever created (by unification) with [Term.Variable] keys.
 */
typealias Substitutions = Map<out Term, Term>

/**
 * A goal maps a state to the sequence of states in which it succeeds;
 * an empty sequence means the goal fails in that state.
 */
typealias Goal = (State) -> Sequence<State>

/**
 * Immutable search state: the variable bindings found so far plus a counter used to
 * allocate fresh logic variables.
 *
 * @property subs bindings accumulated so far.
 * @property vars index the next fresh [Term.Variable] will receive.
 */
data class State(val subs: Substitutions = mapOf(), val vars: Int = 0) {

  /** Returns a copy of this state with [newSubs] merged over the existing bindings (new entries win). */
  fun adding(newSubs: Substitutions): State = copy(subs = subs + newSubs)

  /**
   * Allocates a fresh variable numbered [vars], builds a goal from it with [f], and runs
   * that goal in a state whose counter has been advanced past the new variable.
   */
  fun withNewVar(f: (Term) -> Goal): Sequence<State> {
    val fresh = Term.Variable(vars)
    val advanced = copy(vars = vars + 1)
    val body = f(fresh)
    return body(advanced)
  }

  /** Renders the bindings as `[k1 = v1, k2 = v2]`. */
  override fun toString(): String =
    subs.entries.joinToString(prefix = "[", postfix = "]") { (key, value) -> "$key = $value" }
}

/**
 * Resolves [t] against this state's bindings.
 *
 * - A bound variable is replaced by the (recursively resolved) term it is bound to;
 *   an unbound variable is returned as-is.
 * - Pairs are rebuilt from their resolved components, at every level of nesting.
 * - A binary expression is evaluated over its resolved operands; if the operator's
 *   `evaluate` yields `null`, the original expression is returned unchanged.
 * - Any other term is returned unchanged.
 *
 * NOTE(review): an expression that cannot be evaluated is returned with its operands
 * unresolved, so callers still see bound variables inside it. Rebuilding it with
 * resolved operands needs the `Term.BinaryExpr` constructor, which is not visible here.
 */
fun State.walk(t: Term): Term = when (t) {
  is Term.Variable -> {
    // Follow the binding chain; an unbound variable resolves to itself.
    val bound = subs[t]
    if (bound != null) walk(bound) else t
  }
  // Resolve components with walk (not a single lookup) so that bindings nested inside
  // compound terms are applied and nested expressions get evaluated.
  is Term.Pair -> Term.Pair(walk(t.p), walk(t.q))
  is Term.BinaryExpr -> t.op.evaluate(walk(t.lhs), walk(t.rhs)) ?: t
  else -> t
}

/**
 * Computes the substitutions that make [lhs] and [rhs] equal under this state.
 *
 * Returns `null` when the terms definitely cannot be unified. Otherwise it returns a
 * sequence of alternative substitution sets; reverse-evaluating an operator may yield
 * several of them. Callers pass terms already resolved with [walk]; pair components are
 * walked here as the recursion descends.
 *
 * - identical variables unify with no new bindings;
 * - atoms (`None`, `String`, `Int`, `Boolean`) unify when their values are equal;
 * - pairs unify component-wise, with the head's bindings applied before the tails are unified;
 * - a variable on either side is bound to the other side;
 * - a binary expression with a variable operand is solved by asking its operator to
 *   reverse-evaluate the other operand against the opposite side; other binary
 *   expressions fail.
 */
fun State.unifyExpr(lhs: Term, rhs: Term): Sequence<Substitutions>? {
  // The single "trivially succeeds, binds nothing" result.
  val unified: Sequence<Substitutions> = sequenceOf(mapOf())

  return when {
    // Without this, x = x would bind x to itself and walk() would recurse forever.
    lhs is Term.Variable && lhs == rhs -> unified
    lhs is Term.None && rhs is Term.None -> unified
    lhs is Term.String && rhs is Term.String -> if (lhs.s == rhs.s) unified else null
    lhs is Term.Int && rhs is Term.Int -> if (lhs.i == rhs.i) unified else null
    lhs is Term.Boolean && rhs is Term.Boolean -> if (lhs.b == rhs.b) unified else null
    lhs is Term.Pair && rhs is Term.Pair -> unifyPairs(lhs, rhs)
    lhs is Term.Variable -> sequenceOf(mapOf(lhs to rhs))
    rhs is Term.Variable -> sequenceOf(mapOf(rhs to lhs))
    lhs is Term.BinaryExpr -> when {
      lhs.lhs is Term.Variable -> lhs.op.reverseEvaluateLHS(lhs.rhs, rhs)?.map { mapOf(lhs.lhs to it) }
      lhs.rhs is Term.Variable -> lhs.op.reverseEvaluateRHS(lhs.lhs, rhs)?.map { mapOf(lhs.rhs to it) }
      else -> null
    }
    rhs is Term.BinaryExpr -> when {
      rhs.lhs is Term.Variable -> rhs.op.reverseEvaluateLHS(rhs.rhs, lhs)?.map { mapOf(rhs.lhs to it) }
      rhs.rhs is Term.Variable -> rhs.op.reverseEvaluateRHS(rhs.lhs, lhs)?.map { mapOf(rhs.rhs to it) }
      else -> null
    }
    else -> null
  }
}

/**
 * Unifies two pairs component-wise. The tails are unified under the bindings the heads
 * produced, so a variable that occurs in both components cannot end up bound to two
 * different values (e.g. `Pair(x, x)` does not unify with `Pair(1, 2)`).
 *
 * Returns `null` when either component is a definite mismatch under the current
 * bindings. This keeps the null-means-impossible contract that `disunify` checks.
 */
private fun State.unifyPairs(lhs: Term.Pair, rhs: Term.Pair): Sequence<Substitutions>? {
  val heads = unifyExpr(walk(lhs.p), walk(rhs.p)) ?: return null
  val tails = unifyExpr(walk(lhs.q), walk(rhs.q)) ?: return null
  return heads.flatMap { headSubs ->
    val tailAlternatives: Sequence<Substitutions> =
      if (headSubs.isEmpty()) {
        // No new bindings: the tails were already unified under exactly this state.
        tails
      } else {
        val extended = adding(headSubs)
        extended.unifyExpr(extended.walk(lhs.q), extended.walk(rhs.q)) ?: emptySequence()
      }
    tailAlternatives.map { tailSubs -> headSubs + tailSubs }
  }
}

/**
 * Equality goal body: resolves both terms against this state and yields one extended
 * state per alternative unifier, or no states if the terms cannot be unified.
 */
fun State.unify(lhs: Term, rhs: Term): Sequence<State> {
  val unifiers = unifyExpr(walk(lhs), walk(rhs)) ?: return emptySequence()
  return unifiers.map { bindings -> adding(bindings) }
}

/**
 * Disequality goal body: succeeds with this state unchanged when `unifyExpr` reports
 * the resolved terms as impossible to unify (returns `null`). Fails (empty sequence)
 * whenever it returns any non-null result.
 *
 * NOTE(review): there is no constraint store, so disequality against an unbound
 * variable always fails rather than being deferred.
 */
fun State.disunify(lhs: Term, rhs: Term): Sequence<State> {
  val unifiable = unifyExpr(walk(lhs), walk(rhs)) != null
  return if (unifiable) emptySequence() else sequenceOf(this)
}

/**
 * Wraps the already-computed value [t] in a single-element, re-iterable [Sequence].
 *
 * NOTE(review): despite the name, nothing is deferred — [t] is evaluated by the caller.
 * Within this package, `lazy { ... }` calls may also resolve to this function instead of
 * `kotlin.lazy`; confirm before using `by lazy` here.
 */
fun <T: Any> lazy(t: T): Sequence<T> = sequenceOf(t)

/**
 * Wraps goal [g] so that it is not invoked until the resulting sequence is iterated.
 *
 * Each iteration of the returned sequence calls `g(state)` afresh and delegates to that
 * sequence's iterator. The result can therefore be re-iterated, and a recursive goal
 * only does work on demand.
 */
fun zzz(g: Goal): Goal = { state -> Sequence { g(state).iterator() } }