Skip to content

Instantly share code, notes, and snippets.

@ehamberg
Created December 8, 2011 13:40
Show Gist options
  • Save ehamberg/1447021 to your computer and use it in GitHub Desktop.
Save ehamberg/1447021 to your computer and use it in GitHub Desktop.
Implementation of a BK-Tree in Haskell
import Control.Applicative
import Data.List (foldl')
import qualified Data.Map as M
import Data.Maybe (mapMaybe)
import Text.Read (readMaybe)
-- | A BK-Tree is an n-ary tree. Each non-empty node ('BKTree') has a root word
-- and a map of child trees; a child is stored under the key equal to the
-- Levenshtein distance between its root word and this node's root word.
-- 'Empty' is the tree that contains no words.
data BKTree s = BKTree s (M.Map Int (BKTree s)) | Empty deriving (Show)
-- | Insert a word into a BK-Tree.
--
-- The word is placed along the branch whose length equals the Levenshtein
-- distance between the new word and the root word. When that branch is still
-- free, a fresh leaf is attached there; when a child already occupies it, the
-- insertion continues recursively inside that child.
insertWord :: BKTree String -> String -> BKTree String
insertWord Empty newWord = BKTree newWord M.empty
insertWord (BKTree rootWord ts) newWord =
  BKTree rootWord (M.alter placeAt branch ts)
  where
    -- length of the branch the new word belongs on
    branch = levenshtein rootWord newWord
    -- free slot: hang a new leaf; occupied slot: descend into the child
    placeAt Nothing = Just (BKTree newWord M.empty)
    placeAt (Just child) = Just (insertWord child newWord)
-- | Find every word in the tree whose Levenshtein distance to @queryWord@ is
-- at most @n@.
--
-- At each node the distance @d@ between the node's word and the query word is
-- computed. The node's word is reported when @d <= n@. Only the children on
-- branches of length within @[d-n, d+n]@ are searched further: by the triangle
-- inequality no other subtree can contain a word within distance @n@.
--
-- An 'Empty' tree contains no words, so querying it yields an empty list.
query :: Int -> String -> BKTree String -> [String]
query _ _ Empty = []
query n queryWord (BKTree rootWord ts)
  | d <= n = rootWord : ms
  | otherwise = ms
  where
    -- Levenshtein distance from query word to this node's word
    d = levenshtein rootWord queryWord
    -- child nodes on branches with length in the range [(d-n),(d+n)];
    -- keys that are absent (including negative ones) are simply skipped
    cs = mapMaybe (`M.lookup` ts) [(d - n) .. (d + n)]
    -- recursively query these child nodes and concatenate the results
    ms = concatMap (query n queryWord) cs
-- | Levenshtein (edit) distance between two lists.
--
-- Adapted from
-- en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance
--
-- The dynamic-programming table is built one row per element of @sb@; each
-- row has one cell per prefix of @sa@. Only the previous row is kept, and the
-- answer is the last cell of the final row.
levenshtein :: (Eq t) => [t] -> [t] -> Int
levenshtein sa sb = last (foldl nextRow firstRow sb)
  where
    -- distances from the empty prefix of sb to every prefix of sa
    firstRow = [0 .. length sa]
    -- derive the row for one more element of sb from the previous row
    nextRow prevRow@(corner : shifted) ch =
      scanl cell (corner + 1) (zip3 sa prevRow shifted)
      where
        -- 'left' is the cell just computed in the new row; 'diag' and 'up'
        -- are the neighbouring cells taken from the previous row
        cell left (a, diag, up) =
          let viaUp = up + 1
              viaLeft = left + 1
              viaDiag = if a == ch then diag else diag + 1
           in viaUp `min` viaLeft `min` viaDiag
-- | Interactive query loop.
--
-- Repeatedly prompts for a query word and a maximum distance, then prints the
-- words returned by 'query' for that input. The distance is parsed with
-- 'readMaybe', so a non-integer entry produces a message and a new prompt
-- instead of aborting the program with a parse error.
ask :: BKTree String -> IO ()
ask bk_tree = do
  putStrLn "Enter query word: "
  queryWord <- getLine
  putStrLn "Enter max distance: "
  distInput <- getLine
  case readMaybe distInput of
    Just dist -> print $ query dist queryWord bk_tree
    Nothing -> putStrLn "Invalid distance: please enter an integer."
  -- loop forever; the tree itself is never modified
  ask bk_tree
-- | Program entry point: build a BK-Tree from @dictionary.txt@ and start the
-- interactive query loop.
main :: IO ()
main = do
  -- read dictionary file, skipping empty lines and comments
  dic <- filter (not . comment) . lines <$> readFile "dictionary.txt"
  -- build BK-Tree; the strict left fold evaluates each intermediate tree as
  -- it goes instead of piling up one suspended insertion per dictionary word
  let bk_tree = foldl' insertWord Empty dic
  ask bk_tree
  where
    -- a line is skipped when it is empty or starts with '#'
    comment [] = True
    comment ('#' : _) = True
    comment _ = False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment