package net.peanuuutz.fork.util.common.diff

import net.peanuuutz.fork.util.common.fastForEachIndexedReversed

public val MyersDiff: DiffAlgorithm = object : DiffAlgorithm {
    override fun Differ.execute(oldSize: Int, newSize: Int) {
        executeImpl(
            oldSize = oldSize,
            newSize = newSize
        )
    }

    override fun toString(): String {
        return "MyersDiff"
    }
}

// ======== Internal ========

private typealias Record = IntArray

private fun Differ.executeImpl(oldSize: Int, newSize: Int) {
    val records = evaluateDiff(
        oldSize = oldSize,
        newSize = newSize
    )
    applyDiff(
        records = records,
        oldSize = oldSize,
        newSize = newSize
    )
}

private fun Differ.evaluateDiff(oldSize: Int, newSize: Int): Array<Record?> {
    val max = oldSize + newSize
    val negatableRecord = NegatableRecord(max * 2 + 1)
    val records = arrayOfNulls<Record>(max + 1)
    var d = 0
    outer@ while (d <= max) {
        records[d] = negatableRecord.record.copyOf()
        var k = -d
        while (k <= d) {
            var x = if (k == -d || k != d && negatableRecord[k - 1] < negatableRecord[k + 1]) {
                negatableRecord[k + 1]
            } else {
                negatableRecord[k - 1] + 1
            }
            var y = x - k
            while (x < oldSize && y < newSize && compare(x, y)) {
                x++
                y++
            }
            negatableRecord[k] = x
            if (x >= oldSize && y >= newSize) {
                break@outer
            }
            k += 2
        }
        d++
    }
    return records
}

private fun Differ.applyDiff(records: Array<Record?>, oldSize: Int, newSize: Int) {
    var x = oldSize
    var y = newSize
    records.fastForEachIndexedReversed { d, record ->
        if (record == null) {
            return@fastForEachIndexedReversed
        }
        val negatableRecord = NegatableRecord(record)
        val k = x - y
        val previousK = if (k == -d || k != d && negatableRecord[k - 1] < negatableRecord[k + 1]) {
            k + 1
        } else {
            k - 1
        }
        val previousX = negatableRecord[previousK]
        val previousY = previousX - previousK
        while (x > previousX && y > previousY) {
            x--
            y--
            update(x, y)
        }
        if (d > 0) {
            if (x != previousX) {
                x--
                remove(x)
            }
            if (y != previousY) {
                y--
                insert(y)
            }
        }
    }
}

@JvmInline
private value class NegatableRecord(val record: Record) {
    constructor(size: Int) : this(Record(size))

    operator fun get(index: Int): Int {
        return record[index.negatable]
    }

    operator fun set(index: Int, value: Int) {
        record[index.negatable] = value
    }

    private inline val Int.negatable: Int
        get() = if (this >= 0) this else record.size + this
}
