Created
December 8, 2011 13:40
-
-
Save ehamberg/1447021 to your computer and use it in GitHub Desktop.
Implementation of a BK-Tree in Haskell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import qualified Data.Map as M | |
import Control.Applicative | |
import Data.Maybe (mapMaybe) | |
-- | A BK-tree: an n-ary tree in which every node holds one word and each
-- edge to a child is labelled with the Levenshtein distance between the
-- node's word and the child's root word.
data BKTree s
  = BKTree s (M.Map Int (BKTree s))
    -- ^ A node: its word, plus its subtrees keyed by edge length.
  | Empty
    -- ^ The tree that contains no words.
  deriving (Show)
-- | Insert a word into the tree.
--
-- The word travels along the edge whose length equals the Levenshtein
-- distance between it and the current node's word. If that edge already
-- leads to a subtree, the insertion continues inside that subtree;
-- otherwise the word becomes a new leaf hanging from that edge.
insertWord :: BKTree String -> String -> BKTree String
insertWord Empty newWord = BKTree newWord M.empty
insertWord (BKTree rootWord ts) newWord =
  BKTree rootWord (M.alter (Just . place) edge ts)
  where
    -- Length of the edge the new word has to follow from this node.
    edge = levenshtein rootWord newWord
    -- No subtree on that edge yet: start a leaf. Otherwise recurse into it.
    place = maybe (BKTree newWord M.empty) (`insertWord` newWord)
-- | Collect every word in the tree whose Levenshtein distance to
-- @queryWord@ is at most @n@.
--
-- At each node the distance @d@ between the node's word and the query is
-- computed; the node's word is reported when @d <= n@. Only children
-- hanging from edges whose length lies in @[d-n, d+n]@ are searched
-- further, visited in ascending order of edge length.
--
-- An 'Empty' tree yields no matches.
query :: Int -> String -> BKTree String -> [String]
query _ _ Empty = []
query n queryWord (BKTree rootWord ts)
  | d <= n    = rootWord : ms
  | otherwise = ms
  where
    -- Levenshtein distance from the query word to this node's word
    d = levenshtein rootWord queryWord
    -- Child subtrees on edges in the range [d-n, d+n]. Filtering the keys
    -- that actually exist (rather than enumerating the whole range) keeps
    -- the cost bounded by the number of children, however large n is.
    cs = M.elems (M.filterWithKey (\k _ -> abs (k - d) <= n) ts)
    -- Recursively query those children and concatenate the results
    ms = concatMap (query n queryWord) cs
-- | Levenshtein (edit) distance between two sequences.
--
-- Adapted from
-- en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance
--
-- The table is built one row per element of @sb@, starting from the row
-- @[0 .. length sa]@; the answer is the last cell of the final row.
levenshtein :: (Eq t) => [t] -> [t] -> Int
levenshtein sa sb = last (foldl nextRow firstRow sb)
  where
    firstRow = [0 .. length sa]
    -- Build the next row from the previous one. A row always has
    -- @length sa + 1@ cells, so the non-empty pattern always matches.
    nextRow prev@(corner : rest) ch =
      scanl cell (corner + 1) (zip3 sa prev rest)
      where
        -- left: cell just computed in this row; diag/up: cells of the
        -- previous row diagonally before and directly above this one.
        cell left (a, diag, up) =
          minimum [up + 1, left + 1, diag + substCost]
          where
            substCost = if a /= ch then 1 else 0
-- | Interactive query loop.
--
-- Repeatedly prompts for a query word and a maximum distance, then prints
-- the words found by 'query'. A distance that is not a whole number is
-- reported and the prompts are shown again instead of aborting the
-- program.
ask :: BKTree String -> IO ()
ask bk_tree = do
  putStrLn "Enter query word: "
  queryWord <- getLine
  putStrLn "Enter max distance: "
  distInput <- getLine
  -- 'reads' returns the parsed value together with the unconsumed input;
  -- accept the value only when nothing but whitespace is left over.
  case reads distInput :: [(Int, String)] of
    [(dist, rest)]
      | null (words rest) -> print $ query dist queryWord bk_tree
    _ -> putStrLn "Invalid distance: please enter a whole number."
  ask bk_tree
-- | Program entry point: read the word list from @dictionary.txt@, build a
-- BK-tree from it and start the interactive query loop.
main :: IO ()
main = do
  contents <- readFile "dictionary.txt"
  -- keep only real entries: blank lines and '#' comment lines are skipped
  let entries = [w | w <- lines contents, isEntry w]
      bk_tree = foldl insertWord Empty entries
  ask bk_tree
  where
    isEntry []        = False
    isEntry ('#' : _) = False
    isEntry _         = True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment