Created
December 16, 2023 17:05
-
-
Save msakai/528d33493e716bc4de3632fabbc07ba3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Okapi BM25 document ranking.
--
-- Build an index with 'mkDatabase', then rank its documents against a list
-- of query words with 'query', tuned by 'Params'.
--
-- https://en.wikipedia.org/wiki/Okapi_BM25
module OkapiBM25
  ( Database
  , mkDatabase
    -- 'Params' must be exported (with its fields) or callers outside this
    -- module cannot construct the first argument of 'query'.
  , Params (..)
  , query
  ) where
import Data.IntMap.Strict (IntMap) | |
import qualified Data.IntMap.Strict as IntMap | |
import Data.Map.Strict (Map) | |
import qualified Data.Map.Strict as Map | |
import qualified Data.Set as Set | |
import Data.Text (Text) | |
import qualified Data.Text as Text | |
import Data.Vector (Vector) | |
import qualified Data.Vector as Vector | |
-- | Integer identifier assigned to each distinct word of a corpus.
type WordId = Int

-- | A map keyed by 'WordId'.
type WordIdMap = IntMap.IntMap
-- | Tuning parameters of the BM25 scoring formula.
data Params = Params
  { k1 :: Double
    -- ^ Term-frequency saturation: it appears as @f * (k1 + 1) / (f + k1 * ...)@,
    -- so with @k1 = 0@ repeated occurrences of a term add nothing extra.
  , b :: Double
    -- ^ Document-length normalisation weight in @1 - b + b * |D| / avgdl@:
    -- @0@ ignores document length, @1@ applies the full length ratio.
  }
  deriving (Eq, Ord, Show, Read)
-- | An indexed corpus of documents of type @a@ made of words of type @w@.
--
-- As built by 'mkDatabase', every distinct word is given an integer id
-- @0 .. n-1@; 'wordTable' and 'wordIdTable' translate between the two.
data Database a w = Database
  { documents :: [(a, WordIdMap Int)]
    -- ^ Each document paired with a map from word id to occurrence count.
  , averageDocumentLength :: Double
    -- ^ Mean number of words per document.
  , numDocumentsWithWord :: WordIdMap Int
    -- ^ For each word id, the number of documents that contain it.
  , wordTable :: Vector w
    -- ^ The vocabulary, indexed by word id.
  , wordIdTable :: Map w Int
    -- ^ The word id of each vocabulary word (inverse of 'wordTable').
  }
  deriving (Show)
-- | Index a corpus.
--
-- @wordsOf@ splits a document into its words (e.g. @Text.words@). Each
-- document is kept, in input order, together with its word-occurrence
-- counts. The corpus-wide statistics needed by BM25 are computed once here.
--
-- For an empty corpus the average document length is defined as 0, instead
-- of the @0/0 = NaN@ a plain division would give.
mkDatabase :: Ord w => (a -> [w]) -> [a] -> Database a w
mkDatabase wordsOf xs =
  Database
    { documents = docs
    , averageDocumentLength = avgLen
      -- Each document counts once per distinct word it contains.
    , numDocumentsWithWord = IntMap.unionsWith (+) [1 <$ counts | (_, counts) <- docs]
    , wordTable = Vector.fromList vocabulary
    , wordIdTable = idOf
    }
  where
    tokenized = [(x, wordsOf x) | x <- xs]

    -- Distinct words in ascending order; a word's position is its id.
    vocabulary = Set.toAscList (Set.fromList (concatMap snd tokenized))

    -- The keys are already ascending and distinct, so the linear-time
    -- constructor is valid here.
    idOf = Map.fromDistinctAscList (zip vocabulary [0 ..])

    -- 'Map.!' cannot fail: every word of every document is in 'vocabulary'.
    docs =
      [ (x, IntMap.fromListWith (+) [(idOf Map.! w, 1) | w <- ws])
      | (x, ws) <- tokenized
      ]

    numDocs = length docs
    totalLen = sum [sum (IntMap.elems counts) | (_, counts) <- docs]

    avgLen
      | numDocs == 0 = 0
      | otherwise = fromIntegral totalLen / fromIntegral numDocs
-- | BM25 relevance score of one document (given as word id -> count) for a
-- list of query word ids.
--
-- Each query term @q@ that occurs in the document adds
--
-- > idf(q) * f * (k1 + 1) / (f + k1 * (1 - b + b * |D| / avgdl))
--
-- In this term:
--
-- * @f@ is the count of @q@ in the document.
-- * @|D|@ is the document length.
-- * @avgdl@ is the corpus average document length.
-- * @idf(q) = log (1 + (N - n(q) + 0.5) / (n(q) + 0.5))@, where @N@ is the
--   number of documents and @n(q)@ is the number of documents containing @q@.
--
-- Terms absent from the document contribute exactly 0, because their
-- numerator is 0. They are skipped, which also avoids the @0/0 = NaN@ the
-- formula would otherwise give when @k1 = 0@.
--
-- The corpus- and query-dependent work (@N@ and the idf values) is bound
-- before the document argument. A partial application
-- @score params db queryIds@ can therefore be reused across documents.
score :: Params -> Database a w -> [WordId] -> WordIdMap Int -> Double
score Params{ .. } Database{ .. } queryIds = scoreDoc
  where
    numDocs :: Int
    numDocs = length documents

    idf :: WordId -> Double
    idf q = log $ 1 + (fromIntegral (numDocs - nq) + 0.5) / (fromIntegral nq + 0.5)
      where
        -- Total lookup: an id unknown to this corpus counts as 0 documents.
        nq = IntMap.findWithDefault 0 q numDocumentsWithWord

    weightedTerms :: [(WordId, Double)]
    weightedTerms = [(q, idf q) | q <- queryIds]

    scoreDoc :: WordIdMap Int -> Double
    scoreDoc doc =
      sum
        [ weight * (tf * (k1 + 1)) / (tf + k1 * lengthNorm)
        | (q, weight) <- weightedTerms
        , Just cnt <- [IntMap.lookup q doc]
        , cnt > 0
        , let tf = fromIntegral cnt
        ]
      where
        docLength = fromIntegral (sum (IntMap.elems doc)) :: Double
        lengthNorm = 1 - b + b * docLength / averageDocumentLength
-- | Rank every document of a database against a list of query words.
--
-- Returns one @(document, BM25 score)@ pair per document, in the order the
-- documents were given to 'mkDatabase'.
--
-- Query words that never occur in the corpus have a count of 0 in every
-- document and so contribute nothing to any score. They are dropped up
-- front rather than looked up with a partial 'Map.!'.
query :: Ord w => Params -> Database a w -> [w] -> [(a, Double)]
query params db ws = [(doc, scoreDoc counts) | (doc, counts) <- documents db]
  where
    queryIds = [i | w <- ws, Just i <- [Map.lookup w (wordIdTable db)]]

    -- Score against @db@ itself, so its statistics are used. Binding the
    -- partial application once lets 'score' share its corpus-level work
    -- across all documents.
    scoreDoc = score params db queryIds
-- ------------------------------------------------------------------------
-- Examples taken from
-- https://www.elastic.co/jp/blog/practical-bm25-part-2-the-bm25-algorithm-and-its-variables
-- ------------------------------------------------------------------------
-- | Small example corpus of person names, tokenised on whitespace.
people :: Database Text Text
people = mkDatabase Text.words names
  where
    names =
      [ "Shane"
      , "Shane C"
      , "Shane P Connelly"
      , "Shane Connelly"
      , "Shane Shane Connelly Connelly"
      , "Shane Shane Shane Connelly Connelly Connelly"
      ]
-- | Query 'people' for "Shane" with @k1 = 0@. Each term's contribution
-- reduces to its idf whenever the term is present. Every document
-- containing "Shane" therefore gets the same score, whatever its count
-- or length.
example1 :: [(Text, Double)]
example1 = query params people ["Shane"]
  where
    params = Params{ k1 = 0, b = 0.5 }
-- | Query 'people' for "Shane" with @b = 0@. Document-length normalisation
-- is switched off, so scores depend only on how often "Shane" occurs in
-- each document. This saturates according to @k1 = 10@.
example2 :: [(Text, Double)]
example2 = query params people ["Shane"]
  where
    params = Params{ k1 = 10, b = 0 }
-- | Query 'people' for "Shane" with @b = 1@ (full length normalisation) and
-- @k1 = 5@. Documents longer than average are scored lower for the same
-- count of "Shane".
example3 :: [(Text, Double)]
example3 = query params people ["Shane"]
  where
    params = Params{ k1 = 5, b = 1 }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment