Run BM25 baseline on MTEB retrieval tasks
"""Evaluate BM25 on MTEB tasks | |
Usage: | |
python bm25.py -t <task name> --output_folder=./data/results | |
Notes: | |
- https://github.com/xhluca/bm25s (promissing implememntation) | |
- https://github.com/beir-cellar/beir/blob/main/examples/retrieval/evaluation/lexical/evaluate_bm25.py | |
- https://colab.research.google.com/drive/1HfutiEhHMJLXiWGT8pcipxT5L2TpYEdt?usp=sharing#scrollTo=nqotyXuIBPt6 | |
Requirements: | |
pip install "bm25s[full]" PyStemmer beir | |
""" | |
import argparse
import json
import logging
import os
from pathlib import Path
from time import time
from typing import List, Optional, Union

import bm25s
import Stemmer
from beir.retrieval.evaluation import EvaluateRetrieval
from mteb.abstasks import AbsTaskRetrieval
from mteb.evaluation import MTEB
from mteb.evaluation.evaluators.RetrievalEvaluator import DenseRetrievalExactSearch

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


class BM25Search(DenseRetrievalExactSearch):
    """Override dense retrieval with BM25 search"""
    def __init__(
        self,
        previous_results: Optional[str] = None,
        stopwords: str = "en",
        stemmer_language: Optional[str] = "english",
        **kwargs,
    ):
        super().__init__(
            model=None,
            batch_size=1,
            corpus_chunk_size=1,
            previous_results=previous_results,
            **kwargs,
        )
        self.stopwords = stopwords
        # optional: create a stemmer
        self.stemmer = Stemmer.Stemmer(stemmer_language) if stemmer_language else None

    def search(
        self,
        corpus: dict[str, dict[str, str]],
        queries: dict[str, Union[str, List[str]]],
        top_k: int,
        score_function: str,
        return_sorted: bool = False,
        **kwargs,
    ) -> dict[str, dict[str, float]]:
        logger.info("Encoding Corpus...")
        corpus_ids = list(corpus.keys())
        corpus_with_ids = [{"doc_id": cid, **corpus[cid]} for cid in corpus_ids]
        # concatenate title and text for each document (title may be missing or empty)
        corpus_texts = [
            "\n".join([doc.get("title", ""), doc["text"]]) for doc in corpus_with_ids
        ]
        encoded_corpus = self.encode(corpus_texts)
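        # bm25s.tokenize (with the default return_ids=True) returns a Tokenized
        # tuple; its .ids and .vocab fields are reported in the log line below.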
        logger.info(
            f"Indexing Corpus... {len(encoded_corpus.ids):,} documents, {len(encoded_corpus.vocab):,} vocab"
        )

        # Create the BM25 model and index the corpus
        retriever = bm25s.BM25()
        retriever.index(encoded_corpus)

        logger.info("Encoding Queries...")
        query_ids = list(queries.keys())
        self.results = {qid: {} for qid in query_ids}
        queries_texts = [queries[qid] for qid in query_ids]
        query_token_strs = self.encode(queries_texts, return_ids=False)

        logger.info(f"Retrieving Results... {len(queries):,} queries")
        queries_results, queries_scores = retriever.retrieve(
            query_token_strs, corpus=corpus_with_ids, k=top_k
        )

        # Iterate over queries and collect doc_id -> score mappings
        for qi, qid in enumerate(query_ids):
            query_results = queries_results[qi]
            scores = queries_scores[qi]
            doc_id_to_score = {}

            # Iterate over the retrieved documents for this query
            for doc, score in zip(query_results, scores):
                doc_id_to_score[doc["doc_id"]] = float(score)

            self.results[qid] = doc_id_to_score

        return self.results

    def encode(self, texts: List[str], **kwargs):
        """Encode input text as term vectors"""
        return bm25s.tokenize(
            texts, stopwords=self.stopwords, stemmer=self.stemmer, **kwargs
        )


class BM25MTEB(MTEB):
    """Override eval methods from parent class"""
    def select_tasks(self, **kwargs):
        """Select the tasks to be evaluated."""
        super().select_tasks(**kwargs)

        # Keep only retrieval tasks
        self.tasks = [t for t in self.tasks if isinstance(t, AbsTaskRetrieval)]

    def _run_eval(self, task, model, split, output_folder, **kwargs):
        if model is not None:
            raise ValueError("BM25 does not need a model")

        if not isinstance(task, AbsTaskRetrieval):
            raise ValueError(
                "Only retrieval tasks that inherit from `AbsTaskRetrieval` can be evaluated!"
            )

        tick = time()
        results = self.evaluate_task(task, split, output_folder=output_folder, **kwargs)
        tock = time()

        return results, tick, tock

    def _evaluate_subset(
        self,
        corpus,
        queries,
        relevant_docs,
        hf_subset: str,
        main_score: str,
        k_values=[1, 3, 5, 10, 20, 100, 1000],
        **kwargs,
    ):
        start_time = time()

        # Retrieve and evaluate with BM25 search
        model = BM25Search()
        retriever = EvaluateRetrieval(retriever=model)
        results = retriever.retrieve(corpus, queries)

        end_time = time()
        logger.info(
            "Time taken to retrieve: {:.2f} seconds".format(end_time - start_time)
        )

        if kwargs.get("save_predictions", False):
            output_folder = Path(kwargs.get("output_folder", "results"))
            if not os.path.isdir(output_folder):
                os.makedirs(output_folder)

            top_k = kwargs.get("top_k", None)
            if top_k is not None:
                for qid in list(results.keys()):
                    doc_ids = set(
                        sorted(
                            results[qid], key=lambda x: results[qid][x], reverse=True
                        )[:top_k]
                    )
                    results[qid] = {
                        k: v for k, v in results[qid].items() if k in doc_ids
                    }

            qrels_save_path = (
                output_folder
                / f"{self.metadata_dict['name']}_{hf_subset}_predictions.json"
            )
            with open(qrels_save_path, "w") as f:
                json.dump(results, f)

        ndcg, _map, recall, precision = retriever.evaluate(
            relevant_docs,
            results,
            k_values,
            ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
        )
        mrr = retriever.evaluate_custom(relevant_docs, results, k_values, "mrr")
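
        # BEIR returns metric keys like "NDCG@10", "MAP@10", "Recall@10", "P@10";
        # rename them to MTEB-style keys such as "ndcg_at_10".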
        scores = {
            **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
            **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
            **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
            **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
            **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
        }
        # task._add_main_score(scores)
        scores["main_score"] = scores[main_score]

        return scores

    def evaluate_task(self, task, split="test", **kwargs):
        """Evaluate a specific task"""
        scores = {}
        hf_subsets = (
            list(task.hf_subsets)
            if (task.is_multilingual or task.is_crosslingual)
            else ["default"]
        )
        for hf_subset in hf_subsets:
            logger.info(f"Subset: {hf_subset}")

            if hf_subset == "default":
                corpus, queries, relevant_docs = (
                    task.corpus[split],
                    task.queries[split],
                    task.relevant_docs[split],
                )
            else:
                corpus, queries, relevant_docs = (
                    task.corpus[hf_subset][split],
                    task.queries[hf_subset][split],
                    task.relevant_docs[hf_subset][split],
                )

            scores[hf_subset] = self._evaluate_subset(
                corpus,
                queries,
                relevant_docs,
                hf_subset,
                task.metadata.main_score,
                **kwargs,
            )

        return scores


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_types",
        nargs="+",
        type=str,
        default=None,
        help="List of task types (Clustering, Retrieval..) to be evaluated. If None, all tasks will be evaluated",
    )
    parser.add_argument(
        "--task_categories",
        nargs="+",
        type=str,
        default=None,
        help="List of task categories (s2s, p2p..) to be evaluated. If None, all tasks will be evaluated",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        nargs="+",
        type=str,
        default=None,
        help="List of tasks to be evaluated. If specified, the other arguments are ignored.",
    )
    parser.add_argument(
        "-l",
        "--task-langs",
        nargs="*",
        type=str,
        default=None,
        help="List of languages to be evaluated. If not set, all languages will be evaluated.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for computation"
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default=None,
        help="Output directory for results. Will default to results/{model_name} if not set.",
    )
    parser.add_argument(
        "-v", "--verbosity", type=int, default=2, help="Verbosity level"
    )
    parser.add_argument(
        "--co2_tracker",
        # use a flag instead of `type=bool` (argparse would treat any non-empty string as True)
        action="store_true",
        default=False,
        help="Enable CO₂ tracker, disabled by default",
    )

    ## evaluation params
    parser.add_argument(
        "--eval_splits",
        nargs="+",
        type=str,
        default=None,
        help="Evaluation splits to use (train, dev, test..). If None, all splits will be used",
    )

    ## display tasks
    parser.add_argument(
        "--available_tasks",
        action="store_true",
        default=False,
        help="Display the available tasks",
    )
    # TODO: check what params are useful to add
    args = parser.parse_args()

    # set logging based on verbosity level
    if args.verbosity == 0:
        logging.getLogger("mteb").setLevel(logging.CRITICAL)
    elif args.verbosity == 1:
        logging.getLogger("mteb").setLevel(logging.WARNING)
    elif args.verbosity == 2:
        logging.getLogger("mteb").setLevel(logging.INFO)
    elif args.verbosity == 3:
        logging.getLogger("mteb").setLevel(logging.DEBUG)

    logger.info("Running with parameters: %s", args)

    if args.available_tasks:
        BM25MTEB.mteb_tasks()
        return

    evaluation = BM25MTEB(
        task_categories=args.task_categories,
        task_types=args.task_types,
        task_langs=args.task_langs,
        tasks=args.tasks,
    )
    evaluation.run(
        model=None,
        verbosity=args.verbosity,
        output_folder=args.output_folder,
        eval_splits=args.eval_splits,
        co2_tracker=args.co2_tracker,
    )


if __name__ == "__main__":
    main()