Skip to content

Instantly share code, notes, and snippets.

@cameronpcampbell
Last active May 2, 2025 00:20
Show Gist options
Save cameronpcampbell/9a4273102540c715a95d7ea98278d53b to your computer and use it in GitHub Desktop.
--!strict
--!nolint LocalShadow
-- A node of a byte-wise trie.
-- `is_end`: true when the path from the root down to this node spells a complete
--   inserted string.
-- `children`: maps a single byte value (as returned by `string.byte`) to the
--   child node reached by consuming that byte.
export type TrieNode = {
is_end: boolean,
children: { [number]: TrieNode }
}
-- Reports whether `tble` has no entries at all, in either its array or hash part.
local function table_is_empty(tble: { [any]: any }): boolean
	return next(tble) == nil
end
-- Builds a fresh, empty trie: a root node that ends no word and has no children.
local function trie_create(): TrieNode
	local root: TrieNode = {
		is_end = false,
		children = {},
	}
	return root
end
-- Inserts `str` into the trie.
-- Walks `str` byte by byte from the root, creating any missing child nodes, and
-- marks the node reached after the last byte as a word end.
-- Inserting "" marks the root itself as a word end. This matches how
-- trie_contains(trie, "") and trie_remove(trie, "") read and clear `trie.is_end`.
local function trie_insert(trie: TrieNode, str: string)
	local current_node: TrieNode = trie
	for char_idx = 1, #str do
		local char = string.byte(str, char_idx)
		local next_node = current_node.children[char]
		if not next_node then
			next_node = { is_end = false, children = {} }
			current_node.children[char] = next_node
		end
		current_node = next_node
	end
	current_node.is_end = true
end
-- Reports whether `str` was inserted as a complete word.
-- It follows `str` byte by byte from the root. A missing child means the word is
-- absent; otherwise the answer is the `is_end` flag of the final node.
local function trie_contains(trie: TrieNode, str: string): boolean
	local node: TrieNode? = trie
	local idx, len = 1, #str
	while node ~= nil and idx <= len do
		node = node.children[string.byte(str, idx)]
		idx += 1
	end
	if node == nil then
		return false
	end
	return node.is_end
end
-- Byte codes that may terminate a match in trie_scan when the caller supplies none.
-- Only the ASCII space character is included.
local SPACE_BYTE = (" "):byte()
local DEFAULT_VALID_DELIMITERS: { [number]: any } = { [SPACE_BYTE] = true }
-- Returns an iterator over the trie words found in `str`.
-- Bytes of `str` are walked through the trie one at a time. A match is reported in
-- two cases:
--   * the walk sits on a node that ends a word and the next byte is a key of
--     `valid_delimiters` (default: DEFAULT_VALID_DELIMITERS); the delimiter is consumed;
--   * `str` ends while the walk sits on a word-end node.
-- When the walk cannot continue, it restarts at the root and the offending byte is
-- retried there, so it may begin a new match.
-- Empty matches (the root itself being a word end) are never reported, so the
-- iterator always terminates.
-- After the iterator returns nil, later calls keep returning nil.
local function trie_scan(trie: TrieNode, str: string, valid_delimiters: { [number]: any }?): () -> string?
	local valid_delimiters = valid_delimiters or DEFAULT_VALID_DELIMITERS
	local root_node = trie
	local str_len = #str

	local current_node = root_node
	-- Index in `str` of the first byte of the in-progress match.
	-- Only meaningful while current_node ~= root_node.
	local match_start = 1
	local current_byte_idx = 1

	return function(): string?
		while current_byte_idx <= str_len do
			local current_byte = string.byte(str, current_byte_idx)

			-- A complete, non-empty word followed by a delimiter is a match.
			if current_node ~= root_node and current_node.is_end and valid_delimiters[current_byte] then
				local match = string.sub(str, match_start, current_byte_idx - 1)
				current_node = root_node
				current_byte_idx += 1
				return match
			end

			local child_node = current_node.children[current_byte]
			if not child_node and current_node ~= root_node then
				-- The partial match failed: restart at the root and give this byte a
				-- chance to begin a new match instead of dropping it.
				current_node = root_node
				child_node = root_node.children[current_byte]
			end

			if child_node then
				if current_node == root_node then
					match_start = current_byte_idx
				end
				current_node = child_node
			end
			current_byte_idx += 1
		end

		-- End of string: a complete, non-empty word in progress is the final match.
		if current_node ~= root_node and current_node.is_end then
			current_node = root_node
			return string.sub(str, match_start, str_len)
		end
		return nil
	end
end
-- Recursive worker for trie_remove.
-- `parent_node` is the node reached after consuming the bytes of `str` before
-- `current_char_idx`. The function unmarks the node for the whole of `str` and
-- unlinks every node on the path that no longer leads to any word.
-- Returns true when `parent_node` itself stores nothing: it is not a word end and
-- has no children. The caller may then delete its link to it.
local function trie_remove_main(parent_node: TrieNode, str: string, current_char_idx: number): boolean
	local char = string.byte(str, current_char_idx)
	if char == nil then
		-- This node represents all of `str`: unmark it. It is prunable only if no
		-- longer word passes through it.
		parent_node.is_end = false
		return table_is_empty(parent_node.children)
	end

	local child_node = parent_node.children[char]
	if not child_node then
		-- `str` is not stored in the trie; nothing to remove.
		return false
	end

	if not trie_remove_main(child_node, str, current_char_idx + 1) then
		return false
	end

	-- The child holds no words any more: unlink it, then report whether this node
	-- has become empty too.
	parent_node.children[char] = nil
	return (not parent_node.is_end) and table_is_empty(parent_node.children)
end
-- Removes `str` from the trie.
-- Returns true if `str` was stored as a word (and has now been removed), false
-- otherwise. The internal pruning flag of trie_remove_main is no longer exposed to
-- callers.
local function trie_remove(trie: TrieNode, str: string): boolean
	if not trie_contains(trie, str) then
		return false
	end
	trie_remove_main(trie, str, 1)
	return true
end
-- Renders `trie` as a nested, human-readable string.
-- A node that ends a word gets the prefix "(end)". A node with children is followed
-- by " {", then one line per child of the form `<char> <child rendering>,`, then a
-- closing "}". Indentation is four spaces per `nestedness` level (default 1).
-- Children are listed in ascending byte order so the output is deterministic.
local function trie_stringify_main(trie: TrieNode, nestedness: number?): string
	local prefix = if trie.is_end then "(end)" else ""

	-- Collect and sort the child bytes, since table iteration order is unspecified.
	local child_bytes = {}
	for child_byte in trie.children do
		table.insert(child_bytes, child_byte)
	end
	if #child_bytes == 0 then
		return prefix
	end
	table.sort(child_bytes)

	local nestedness = nestedness or 1
	local prev_tabs = string.rep(" ", (nestedness - 1) * 4)
	local tabs = string.rep(" ", nestedness * 4)

	local parts = { prefix, " {" }
	for _, child_byte in child_bytes do
		local char = string.char(child_byte)
		-- Quote the space character so it stays visible in the output.
		local label = if char == " " then "\" \"" else char
		local child_text = trie_stringify_main(trie.children[child_byte], nestedness + 1)
		table.insert(parts, "\n" .. tabs .. label .. " " .. child_text .. ",")
	end
	table.insert(parts, "\n" .. prev_tabs .. "}")
	return table.concat(parts)
end
-- Produces a human-readable, multi-line rendering of the whole trie, starting at
-- the top nesting level.
local function trie_stringify(trie: TrieNode)
	local top_level_nestedness = 1
	local rendered = trie_stringify_main(trie, top_level_nestedness)
	return rendered
end
-- Public API of the module. Every function except `create` takes a trie node
-- (normally the root returned by `create`) as its first argument.
return {
	create = trie_create,
	insert = trie_insert,
	remove = trie_remove,
	contains = trie_contains,
	scan = trie_scan,
	stringify = trie_stringify,
}
@cameronpcampbell
Copy link
Author

cameronpcampbell commented Apr 30, 2025

Example Usage:

local trie = require("path/to/trie")

local my_trie = trie.create()

trie.insert(my_trie, "hello")
trie.insert(my_trie, "hello_world")

print(trie.contains(my_trie, "foo")) -- false
print(trie.contains(my_trie, "hello")) -- true

trie.remove(my_trie, "hello")
print(trie.contains(my_trie, "hello")) -- false

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment