Skip to content

Instantly share code, notes, and snippets.

@mahiro72
Last active March 10, 2025 17:16
Show Gist options
  • Save mahiro72/4acf4668f20b648c72212f2d760b650c to your computer and use it in GitHub Desktop.
Save mahiro72/4acf4668f20b648c72212f2d760b650c to your computer and use it in GitHub Desktop.
Python Trie Implementation with Frequency Cache
from typing import Dict
class TrieNode:
def __init__(self):
self.children: Dict[str, TrieNode] = {}
self.is_end = False
self.freq = 0
self.top_words: Dict[str, int] = {}
class Trie:
def __init__(self, k: int = 5):
self.root = TrieNode()
self.k = k # トップk個の単語を保持する
def bulk_insert(self, word_freq: Dict[str, int]):
for word, freq in word_freq.items():
self._insert_with_freq(word, freq)
def _insert_with_freq(self, word: str, freq: int):
node = self.root
path_nodes = [node]
for char in word:
if char not in node.children:
node.children[char] = TrieNode()
node = node.children[char]
path_nodes.append(node)
# 単語の終端をマーク
node.is_end = True
node.freq = freq
# パス上の全ノードのtop_wordsキャッシュを更新
for n in reversed(path_nodes):
merged_words = {} if len(n.top_words) == 0 else n.top_words.copy()
if n.is_end:
merged_words[word] = freq
else:
for child in n.children.values(): #自身が単語の終端ではない場合、子ノードのtop_wordsをマージする
for w, f in child.top_words.items():
merged_words[w] = f
n.top_words = dict(sorted(merged_words.items(), key=lambda x: (-x[1], x[0]))[:self.k])
def find_top_k_prefixes(self, prefix: str):
node = self.root
for char in prefix:
if char not in node.children:
return []
node = node.children[char]
# ヒットしたノードのtop_wordsのキャッシュを返す
return node.top_words
def print_trie(self):
# デバッグ用
def _print_node(node, prefix="", level=0):
indent = " " * level
if node.is_end:
print(f"{indent}{prefix} (freq={node.freq})")
print(f"{indent}{prefix} {node.top_words}")
for char, child in sorted(node.children.items()):
_print_node(child, prefix + char, level + 1)
_print_node(self.root)
if __name__ == "__main__":
word_freq = {
"tree": 10,
"true": 35,
"try": 29,
"toy": 14,
"wish": 25,
"win": 50,
}
trie = Trie()
trie.bulk_insert(word_freq)
"""
{'win': 50, 'true': 35, 'try': 29, 'wish': 25, 'toy': 14}
t {'true': 35, 'try': 29, 'toy': 14, 'tree': 10}
to {'toy': 14}
toy (freq=14)
toy {'toy': 14}
tr {'true': 35, 'try': 29, 'tree': 10}
tre {'tree': 10}
tree (freq=10)
tree {'tree': 10}
tru {'true': 35}
true (freq=35)
true {'true': 35}
try (freq=29)
try {'try': 29}
w {'win': 50, 'wish': 25}
wi {'win': 50, 'wish': 25}
win (freq=50)
win {'win': 50}
wis {'wish': 25}
wish (freq=25)
wish {'wish': 25}
"""
trie.print_trie()
# print(trie.find_top_k_prefixes("tr")) # -> {'true': 35, 'try': 29, 'tree': 10}
@mahiro72
Copy link
Author

Implementation of Trie Data Structure in Python with Frequency Cache

概要

Pythonを用いたトライデータ構造の実装です。
特徴として、単語と頻度のtopKのキャッシュを各nodeで持っており、これにより特定の単語のtopK検索時はO(1)で取得できます。ただしデータ更新が手間で、新しい単語が追加された場合、その親ノードのキャッシュを全て更新する必要があります。
そのため、頻繁に更新されるトライ構造は想定しておらず、定期的なデータバッチによってトライ構造を毎回新規作成し高速な検索システムを提供するために使う想定です。

実装の特徴

  • 各ノードはchildrenとtop_wordsをdictで保持
  • top_wordsは{word: frequency}の形式で、頻度順にソートされた上位k個の単語を保持
  • データ挿入時は子ノードから親ノードに向かってボトムアップでキャッシュを更新

ユースケース

  • 検索エンジンのオートコンプリート機能

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment