Skip to content

Instantly share code, notes, and snippets.

@magicoal-nerb
Last active November 12, 2025 09:02
Show Gist options
  • Save magicoal-nerb/98090e48e21f7c904ba9998a32a43551 to your computer and use it in GitHub Desktop.
Save magicoal-nerb/98090e48e21f7c904ba9998a32a43551 to your computer and use it in GitHub Desktop.
burkhard keller tree
--!strict
-- BK (Burkhard-Keller) trees are used for fast string lookup queries.
-- You have to use Levenshtein distance as the distance function.
-- Basically, I just knocked down a few constant factors,
-- so it is marginally faster than most implementations.
-- The asymptotic performance is basically the same, though.
-- https://gist.github.com/magicoal-nerb/6c91120e671d557de59e6d87e6617868
local AvlTree = require("./AvlTree")
-- https://gist.github.com/magicoal-nerb/676b02597d3ab859b3e9f062ef5280c1
local Queue = require("./Queue")
local BkTree = {}
BkTree.__index = BkTree
-- Signature of a string distance function: takes the two strings to compare
-- plus a numeric third argument (insert passes math.huge, query passes its
-- search radius) and returns the distance as a number.
export type BkDistance = (string, string, number) -> number
-- A single node of the BK tree.
export type BkNode = {
-- The string stored at this node.
value: string,
-- Distance between this node's value and its parent's value, as computed
-- when the node was inserted; the root stores 0.
distance: number,
-- Child nodes, kept in an AVL tree keyed by each child's `distance`.
children: AvlTree.Avl<number, BkNode>,
}
-- The BK tree object: a table with the BkTree metatable attached.
export type BkTree<T> = typeof(setmetatable({} :: {
-- Distance function used for both insertion and querying.
distance: BkDistance,
-- The table handed to BkTree.new; its string keys are the indexed values.
dataset: { [string]: T },
-- Root node; nil until the first insert.
root: BkNode?,
}, BkTree))
-- Ordering predicate handed to AvlTree.new for the child trees:
-- true only when `a` sorts strictly before `b`.
local function compare(a: number, b: number)
	return b > a
end
-- Enqueues every node of an AVL child tree whose `distance` key lies in the
-- inclusive range [distanceA, distanceB], each node exactly once and in
-- ascending key order.
--
-- tree:      AVL tree of BkNodes keyed by distance (smaller keys are assumed
--            to live in the left subtree, as the traversal relies on it).
-- distanceA: lower bound of the range (inclusive).
-- distanceB: upper bound of the range (inclusive).
-- queue:     receives the matching BkNodes through queue:enqueue(node).
--
-- The walk is an iterative in-order traversal that starts at the first node
-- >= distanceA and stops at the first node > distanceB. Only nodes that still
-- have to be visited are ever placed on the stack; pushing nodes that the
-- seek leaves to the right would visit them out of order, re-walk their right
-- subtree (duplicate results) and let the early exit skip a node whose
-- distance equals distanceA.
local function rangeQuery(
	tree: AvlTree.Avl<number, BkNode>,
	distanceA: number,
	distanceB: number,
	queue: Queue.Queue<BkNode>
)
	local stack = {}

	-- Seek phase: descend towards distanceA. A node that is >= distanceA is a
	-- candidate (and so is everything to its right), so remember it and look
	-- for something smaller on the left. A node below the range can be
	-- discarded together with its whole left subtree.
	local current = tree.root
	while current do
		if (current.data :: BkNode).distance >= distanceA then
			table.insert(stack, current)
			current = current.left
		else
			current = current.right
		end
	end

	-- In-order phase: every node on the stack is >= distanceA, and the top of
	-- the stack is always the smallest node not yet visited.
	while #stack > 0 do
		local node = table.remove(stack)
		if not node then
			break
		end

		local distance = (node.data :: BkNode).distance
		if distance > distanceB then
			-- Keys only grow from here on, nothing else can match
			return
		end

		queue:enqueue(node.data)

		-- The in-order successor is the leftmost node of the right subtree;
		-- all of those keys are greater than this node's, hence >= distanceA.
		current = node.right
		while current do
			table.insert(stack, current)
			current = current.left
		end
	end
end
-- Builds a BK tree over `dataset`.
-- dataset:  table whose string keys are the values to index; it is kept on
--           the tree so getEntry can look entries up by key.
-- distance: distance function used for every insert and query.
-- Returns the new tree with every key of `dataset` already inserted.
function BkTree.new<T>(
	dataset: { [string]: T },
	distance: BkDistance
)
	local tree = setmetatable({
		distance = distance,
		dataset = dataset,
	}, BkTree)

	-- Only the keys are indexed; the associated values stay in `dataset`
	for key in dataset do
		tree:insert(key)
	end

	return tree
end
-- Looks up the dataset value stored under `key`.
-- Returns whatever `self.dataset[key]` holds (nil when the key is absent).
function BkTree.getEntry<T>(self: BkTree<T>, key: string): T
	local entry = self.dataset[key]
	return entry
end
-- Inserts `name` into the tree.
-- The first inserted string becomes the root. Otherwise the tree is walked
-- downwards: at each node the distance to `name` is computed, and if a child
-- already exists at that distance the walk continues in that child;
-- if not, a new child is created there.
function BkTree.insert<T>(self: BkTree<T>, name: string)
	local root = self.root
	if not root then
		-- Empty tree: this string becomes the root
		self.root = {
			value = name,
			distance = 0.0,
			children = AvlTree.new(compare),
		}
		return
	end

	local measure = self.distance
	local current: BkNode = root
	while true do
		local siblings = current.children
		-- math.huge as the third argument: no cap on the distance here
		local cost = measure(name, current.value, math.huge)
		local existing = AvlTree.get(siblings, cost)
		if not existing then
			-- Free slot at this distance: add it to the child AVL tree
			AvlTree.insert(siblings, cost, {
				value = name,
				distance = cost,
				children = AvlTree.new(compare),
			})
			return
		end

		-- Slot already taken: keep descending through that child
		current = existing
	end
end
-- Finds every stored string whose distance to `name` is at most `n`.
-- name: the string to search for.
-- n:    maximum allowed distance (also passed to the distance function as
--       its third argument).
-- Returns the matching strings without duplicates, sorted by ascending
-- distance to `name`. An empty tree yields an empty list.
function BkTree.query<T>(self: BkTree<T>, name: string, n: number): { string }
	local root = self.root
	if not root then
		-- Nothing has been inserted yet, so there is nothing to search
		return {}
	end

	local distance = self.distance

	-- Breadth-first walk over the candidate nodes
	local pending = Queue.new() :: Queue.Queue<BkNode>
	pending:enqueue(root)

	-- costs doubles as the "already collected" set and as the sort key,
	-- so each match's distance is only ever computed once
	local matches: { string } = {}
	local costs: { [string]: number } = {}
	while not pending:empty() do
		local node = pending:dequeue()
		local cost = distance(name, node.value, n)
		if math.abs(cost) <= n and costs[node.value] == nil then
			-- Within the threshold and not reported yet
			costs[node.value] = cost
			table.insert(matches, node.value)
		end

		-- Only children whose stored distance lies in [cost - n, cost + n]
		-- can still be within n of `name`
		rangeQuery(node.children, cost - n, cost + n, pending)
	end

	table.sort(matches, function(a: string, b: string)
		return costs[a] < costs[b]
	end)

	return matches
end
return BkTree
--!strict
-- Levenshtein edit distance between `a` and `b`, compared byte by byte.
-- a, b: the strings to compare.
-- dist: accepted so the function fits a three-argument distance signature;
--       it is not used and the exact distance is always returned.
-- Returns the minimum number of single-byte insertions, deletions and
-- substitutions needed to turn one string into the other.
return function(a: string, b: string, dist: number): number
	local lenA = #a
	local lenB = #b

	-- Two rolling rows of the DP table, each lenA + 1 wide.
	-- previous[j + 1] = distance between the first j bytes of `a` and the
	-- prefix of `b` handled so far.
	local previous = table.create(lenA + 1, 0)
	local current = table.create(lenA + 1, 0)
	for column = 0, lenA do
		previous[column + 1] = column
	end

	for row = 1, lenB do
		-- The byte of `b` is the same for the whole row, fetch it once
		local byteB = string.byte(b, row)

		-- Turning an empty prefix of `a` into `row` bytes of `b`
		current[1] = row

		for column = 1, lenA do
			-- Substitution is free when the bytes already match
			local substitution = previous[column]
			if string.byte(a, column) ~= byteB then
				substitution += 1
			end

			current[column + 1] = math.min(
				previous[column + 1] + 1,
				current[column] + 1,
				substitution
			)
		end

		-- swap rows
		previous, current = current, previous
	end

	return previous[lenA + 1]
end
--!strict
-- https://siderite.dev/blog/super-fast-and-accurate-string-distance.html
-- Approximate string distance following the algorithm described at the
-- siderite.dev link above: counts the bytes both strings have in common
-- while scanning them in parallel, and subtracts that from the longer length.
-- a, b: the strings to compare.
-- max:  how far ahead (in bytes) to look for a matching byte after a mismatch.
-- Returns math.max(#a, #b) minus the number of common bytes found.
return function(a: string, b: string, max: number)
	local lenA, lenB = #a, #b
	local posA, posB = 1, 1

	-- matched: common bytes from runs that are already closed
	-- run:     common bytes in the run currently being extended
	local matched = 0
	local run = 0

	while posA <= lenA and posB <= lenB do
		if string.byte(a, posA) ~= string.byte(b, posB) then
			-- Mismatch: close the current run
			matched += run
			run = 0

			-- Bring both cursors level with whichever is further ahead
			if posA ~= posB then
				local furthest = math.max(posA, posB)
				posA = furthest
				posB = furthest
			end

			-- Look ahead in either string for a byte that lines up again
			local offset = 0
			while offset < max and (posA + offset <= lenA or posB + offset <= lenB) do
				if posA + offset <= lenA and string.byte(a, posA + offset) == string.byte(b, posB) then
					posA += offset
					run += 1
					break
				elseif posB + offset <= lenB and string.byte(a, posA) == string.byte(b, posB + offset) then
					posB += offset
					run += 1
					break
				end
				offset += 1
			end
		else
			-- Same byte on both sides: extend the current run
			run += 1
		end

		posA += 1
		posB += 1
	end

	matched += run
	return math.max(lenA, lenB) - matched
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment