Last active
November 12, 2025 09:02
-
-
Save magicoal-nerb/98090e48e21f7c904ba9998a32a43551 to your computer and use it in GitHub Desktop.
Burkhard–Keller tree (BK tree)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --!strict | |
-- BK (Burkhard–Keller) trees answer "all strings within distance n" queries quickly.
-- They need a distance function; Levenshtein is the usual choice, because the
-- tree's pruning is only exact when the distance is a true metric.
-- This implementation trims a few constant factors, so it is marginally faster
-- than most implementations, but its asymptotic performance is the same.
| -- https://gist.github.com/magicoal-nerb/6c91120e671d557de59e6d87e6617868 | |
| local AvlTree = require("./AvlTree") | |
| -- https://gist.github.com/magicoal-nerb/676b02597d3ab859b3e9f062ef5280c1 | |
| local Queue = require("./Queue") | |
local BkTree = {}
BkTree.__index = BkTree

-- Distance callback used to label edges and to answer queries.
-- Receives the two strings being compared plus a numeric hint (insert passes
-- `math.huge`), and returns the distance between the two strings.
export type BkDistance = (string, string, number) -> number

-- A single BK-tree node: the stored string, the distance to its parent
-- (0 for the root), and its children keyed by that distance.
export type BkNode = {
	value: string,
	distance: number,
	children: AvlTree.Avl<number, BkNode>,
}

-- Per-instance fields of a BK tree.
type BkTreeFields<T> = {
	distance: BkDistance, -- distance callback
	dataset: { [string]: T }, -- indexed strings -> caller payloads
	root: BkNode?, -- nil until the first insert
}

export type BkTree<T> = typeof(setmetatable({} :: BkTreeFields<T>, BkTree))
-- Strict "less than" ordering for the numeric edge keys of each node's child AVL tree.
local function compare(left: number, right: number): boolean
	return right > left
end
-- Enqueues, in ascending key order, every BkNode stored in `tree` whose edge
-- distance lies in the closed interval [distanceA, distanceB].
-- The AVL tree is ordered by `compare` on the edge distance (the key that
-- BkTree.insert stores equals the node's `distance` field). The walk is an
-- in-order traversal that starts at the smallest key >= distanceA and stops at
-- the first key > distanceB, so each matching node is enqueued exactly once.
local function rangeQuery(
	tree: AvlTree.Avl<number, BkNode>,
	distanceA: number,
	distanceB: number,
	queue: Queue.Queue<BkNode>
)
	local stack = {}
	local top = 0

	-- Seed the stack with the path to the lower bound, pushing only the nodes
	-- whose key is >= distanceA. A node below distanceA can only have in-range
	-- keys in its right subtree, which we are about to descend into. Pushing it
	-- anyway would make the walk revisit that subtree and enqueue nodes twice.
	local current = tree.root
	while current do
		if (current.data :: BkNode).distance >= distanceA then
			top += 1
			stack[top] = current
			current = current.left
		else
			current = current.right
		end
	end

	while top > 0 do
		local avlNode = stack[top]
		stack[top] = nil
		top -= 1

		local data = avlNode.data :: BkNode
		if data.distance > distanceB then
			-- Keys only grow from here on, so nothing else can be in range.
			return
		end
		queue:enqueue(data)

		-- The in-order successor is the leftmost path of the right subtree.
		current = avlNode.right
		while current do
			top += 1
			stack[top] = current
			current = current.left
		end
	end
end
-- Builds a BK tree whose nodes are the keys of `dataset`, with edges measured by
-- `distance`. The dataset table is kept by reference so getEntry can map a
-- matched key back to its value.
function BkTree.new<T>(dataset: { [string]: T }, distance: BkDistance)
	local tree = setmetatable({
		dataset = dataset,
		distance = distance,
	}, BkTree)

	-- Only the keys are indexed; the values stay in `dataset`.
	for key in dataset do
		tree:insert(key)
	end

	return tree
end
-- Returns the payload stored in the dataset under `key` (nil if the key is absent).
function BkTree.getEntry<T>(self: BkTree<T>, key: string): T
	local dataset = self.dataset
	return dataset[key]
end
-- Adds `name` to the tree.
-- The first string becomes the root, with distance 0. Each later string walks
-- down from the root: at every node it follows the child whose edge key equals
-- its distance to that node. At the first node with no such child, it is
-- attached as a new child keyed by that distance.
function BkTree.insert<T>(self: BkTree<T>, name: string)
	local current = self.root
	if not current then
		self.root = {
			value = name,
			distance = 0.0,
			children = AvlTree.new(compare),
		}
		return
	end

	local measure = self.distance
	while true do
		local siblings = current.children
		-- Edge label: distance from `name` to this node (hint argument math.huge).
		local edge = measure(name, current.value, math.huge)
		local child = AvlTree.get(siblings, edge)
		if not child then
			AvlTree.insert(siblings, edge, {
				value = name,
				distance = edge,
				children = AvlTree.new(compare),
			})
			return
		end
		current = child
	end
end
-- Returns every indexed string whose distance to `name` is at most `n`, sorted
-- by ascending distance (ties broken alphabetically). Returns an empty list
-- when nothing has been inserted.
function BkTree.query<T>(self: BkTree<T>, name: string, n: number): { string }
	local root = self.root
	if not root then
		-- Empty tree: enqueueing a nil root would hand `nil` to the loop below.
		return {}
	end

	local distance = self.distance
	local pending = Queue.new() :: Queue.Queue<BkNode>
	pending:enqueue(root)

	local matches: { string } = {}
	-- Distance of each accepted string. It also de-duplicates the results,
	-- because inserting the same string twice stores it as a second node.
	local costs: { [string]: number } = {}

	while not pending:empty() do
		local node = pending:dequeue()
		-- Pass the same `math.huge` hint that insert used to label the edges.
		-- `cost` must be measured exactly like the child keys it is compared against.
		local cost = distance(name, node.value, math.huge)
		if math.abs(cost) <= n and costs[node.value] == nil then
			costs[node.value] = cost
			table.insert(matches, node.value)
		end
		-- BK-tree pruning: only children whose edge lies in [cost - n, cost + n]
		-- can lead to matches. This is exact when `distance` is a metric.
		rangeQuery(node.children, cost - n, cost + n, pending)
	end

	-- Reuse the distances computed above instead of re-measuring every match.
	table.sort(matches, function(a: string, b: string)
		local costA, costB = costs[a], costs[b]
		if costA ~= costB then
			return costA < costB
		end
		return a < b
	end)

	return matches
end

return BkTree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --!strict | |
return function(a: string, b: string, dist: number): number
	-- Two-row Wagner–Fischer Levenshtein distance between `a` and `b`.
	-- `dist` is accepted only to match the BkDistance signature and is unused:
	-- the exact distance is always returned.
	local m = #b
	local n = #a

	-- v0 is the previous row and v1 the row being filled. Entry j + 1 is the
	-- distance between the first j bytes of `a` and the current prefix of `b`.
	local v0 = table.create(n + 1, 0)
	local v1 = table.create(n + 1, 0)
	for j = 1, n + 1 do
		v0[j] = j - 1
	end

	for i = 1, m do
		v1[1] = i
		-- Hoisted: the byte of `b` does not change across the inner loop.
		local byteB = string.byte(b, i)
		for j = 1, n do
			local fromAbove = v0[j + 1] + 1
			local fromLeft = v1[j] + 1
			local diagonal = v0[j]
			if string.byte(a, j) ~= byteB then
				diagonal += 1
			end
			v1[j + 1] = math.min(fromAbove, fromLeft, diagonal)
		end
		-- Swap rows: the row just filled becomes the previous row.
		v0, v1 = v1, v0
	end

	return v0[n + 1]
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --!strict | |
| -- https://siderite.dev/blog/super-fast-and-accurate-string-distance.html | |
return function(a: string, b: string, max: number)
	-- Sift4 (simplest variant) approximate string distance.
	-- Walks both strings in lockstep. On a mismatch it searches up to `max`
	-- positions ahead in either string for a point where they line up again.
	-- Result: length of the longer string minus the bytes counted as common.
	-- Note: `max` is a search window (maximum offset), not a distance cap.
	local lengthA, lengthB = #a, #b
	local posA, posB = 1, 1
	local committed = 0 -- common bytes from finished runs
	local run = 0 -- common bytes in the current run

	while posA <= lengthA and posB <= lengthB do
		if string.byte(a, posA) == string.byte(b, posB) then
			run += 1
			posA += 1
			posB += 1
		else
			committed += run
			run = 0

			-- Re-align both cursors to the further one (no-op when already equal).
			local aligned = math.max(posA, posB)
			posA, posB = aligned, aligned

			-- Look ahead in each string for the other string's current byte.
			local offset = 0
			while offset < max do
				local inA = posA + offset <= lengthA
				local inB = posB + offset <= lengthB
				if not (inA or inB) then
					break
				end
				if inA and string.byte(a, posA + offset) == string.byte(b, posB) then
					posA += offset
					run += 1
					break
				end
				if inB and string.byte(a, posA) == string.byte(b, posB + offset) then
					posB += offset
					run += 1
					break
				end
				offset += 1
			end

			posA += 1
			posB += 1
		end
	end

	return math.max(lengthA, lengthB) - (committed + run)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment