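"""Sparse retrieval over a Pyserini Lucene impact index, with optional MaxSim
reranking from precomputed sparse query/corpus term vectors; results are
written as a TREC run file.

Example invocation (script name and paths are illustrative, not part of the
original gist):

    python retrieve_impact_rerank.py \
        --index indexes/impact-index \
        --query_path queries.tsv \
        --query_embedding_path query_term_weights.jsonl \
        --corpus_path corpus.tsv \
        --sparse_query_vec_path sparse/queries \
        --sparse_corpus_vec_path sparse/corpus \
        --output_path runs/dev \
        --topk 1000 --batch_size 128 --threads 8
"""
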
import os
import time
import argparse
import numpy as np
import pickle
import collections
import jsonlines
import torch
import glob
import scipy.sparse
from functools import partial
from multiprocessing import Pool
from tqdm import tqdm
from pyserini.search.lucene import LuceneImpactSearcher
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap

try:
    from dpr_scale.retriever_ext import scatter as c_scatter
except ImportError:
    raise ImportError(
        'Cannot import scatter module.'
        ' Make sure you have compiled the retriever extension.'
    )


def load_file(path, i):
    data = torch.load(path)
    return (data, i)


def maxsim(entry):
    # MaxSim scoring over sparse term vectors: for each query term, take the
    # maximum similarity across a document's terms, then sum over query terms.
    q_embed, d_embeds, d_lens, qid, scores, docids = entry
    if len(d_embeds) == 0:
        return qid, scores, docids
    d_embeds = scipy.sparse.vstack(d_embeds).transpose()  # (LD x 1000) x D
    max_scores = (q_embed @ d_embeds).todense()  # LQ x (LD x 1000)
    scores = []
    start = 0
    for d_len in d_lens:
        scores.append(max_scores[:, start:start + d_len].max(1).sum())
        start += d_len
    # Sort candidates by descending rerank score.
    scores, docids = list(zip(*sorted(zip(scores, docids), key=lambda x: -x[0])))
    return qid, scores, docids


class LuceneMultiTermSearcher(LuceneImpactSearcher):
    def __init__(self, query_path, query_embedding_path, sparse_query_path, sparse_corpus_path,
                 corpus_path, weight_threshold, threads, id2idx_path=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_threshold = weight_threshold
        self.threads = threads
        self.queries = {}
        print("Loading queries...")
        if query_path.split(".")[-1] == "tsv":
            with open(query_path) as f:
                lines = f.readlines()
            for line in lines:
                qid, query = line.strip().split("\t")
                self.queries[qid] = query
        elif query_path.split(".")[-1] == "jsonl":
            with jsonlines.open(query_path) as f:
                for line in f:
                    self.queries[line["question_id"]] = line["question"]
        print("Loading query embeddings...")
        self.query_embeddings = self.query_preprocess(query_embedding_path)
        print("Loading corpus...")
        with open(corpus_path) as f:
            lines = f.readlines()
        self.corpus_len = len(lines) - 1  # exclude header
        self.sparse_q_vecs = None
        self.sparse_vecs = None
        if sparse_query_path is not None and os.path.exists(sparse_query_path):
            self.pool = Pool(threads)
            print("Loading sparse query vectors for fast reranking...")
            sparse_query_range_path = os.path.join(sparse_query_path, "sparse_range.pkl")
            with open(sparse_query_range_path, "rb") as f:
                sparse_q_ranges = pickle.load(f)
            sparse_query_vec_path = os.path.join(sparse_query_path, "sparse_vec.npz")
            sparse_q_vecs = scipy.sparse.load_npz(sparse_query_vec_path)
            sparse_q_vecs_scatter = []
            for start, end in sparse_q_ranges:
                sparse_q_vecs_scatter.append(sparse_q_vecs[start:end])
            if id2idx_path is None:
                self.sparse_q_vecs = {k: v for k, v in zip(list(self.query_embeddings.keys()), sparse_q_vecs_scatter)}
            else:
                with open(id2idx_path, "rb") as f:
                    id2idx = pickle.load(f)
                self.sparse_q_vecs = {k: sparse_q_vecs_scatter[id2idx[k]] for k in self.query_embeddings.keys()}
            print("Loading sparse corpus vectors for fast reranking...")
            sparse_corpus_range_path = os.path.join(sparse_corpus_path, "sparse_range.pkl")
            with open(sparse_corpus_range_path, "rb") as f:
                self.sparse_ranges = pickle.load(f)
            sparse_corpus_vec_path = os.path.join(sparse_corpus_path, "sparse_vec.npz")
            sparse_vecs = scipy.sparse.load_npz(sparse_corpus_vec_path)
            self.sparse_vecs = []
            for start, end in tqdm(self.sparse_ranges):
                self.sparse_vecs.append(sparse_vecs[start:end])

    def query_preprocess(self, embedding_path):
        # Aggregate per-token query term weights keyed by topic id; ids of the form
        # "<topic_id>_<pos>" are merged into a single weighted term vector per topic.
        upper_embeddings = collections.defaultdict(dict)
        with jsonlines.open(embedding_path) as f:
            for line in f:
                if len(line["vector"]) > 0:
                    topic_pos_id = line["id"]
                    splits = topic_pos_id.split("_")
                    pos = splits[-1]
                    topic_id = "_".join(splits[:-1])
                    if topic_id in self.queries:
                        for term, weight in line["vector"].items():
                            if weight > self.weight_threshold:
                                upper_embeddings[topic_id][term] = upper_embeddings[topic_id].get(term, 0) + weight
        return upper_embeddings

    def batch_search(self, topk, threads, batch_size):
        query_lst = JArrayList()
        qid_lst = JArrayList()
        qids = []
        ranking = {}
        iterator = list(self.query_embeddings.items())
        count = 0
        print("Searching...")
        latency = 0
        for i, entry in tqdm(list(enumerate(iterator))):
            qid, vector = entry
            # Build a weighted-term query, dropping low-weight and low-IDF tokens.
            jquery = JHashMap()
            for token, weight in vector.items():
                if weight > self.weight_threshold and token in self.idf and self.idf[token] > self.min_idf:
                    jquery.put(token, JFloat(weight))
            query_lst.add(jquery)
            qid_lst.add(qid)
            qids.append(qid)
            count += 1
            if count == batch_size or i == len(iterator) - 1:
                tic = time.perf_counter()
                raw_results = self.object.batch_search(query_lst, qid_lst, topk, threads)
                results = {r.getKey(): r.getValue() for r in raw_results.entrySet().toArray()}
                all_scores = []
                all_docids = []
                for qid in qids:
                    hits = results[qid]
                    docids = []
                    scores = []
                    for hit in hits:
                        docids.append(int(hit.docid))
                        scores.append(hit.score)
                    all_scores.append(scores)
                    all_docids.append(docids)
                if self.sparse_vecs is not None:
                    # Rerank the first-stage hits with exact MaxSim scores.
                    qids, all_scores, all_docids = self.fast_rerank(qids, all_scores, all_docids)
                for qid, scores, docids in zip(qids, all_scores, all_docids):
                    ranking[qid] = (scores, docids)
                toc = time.perf_counter()
                latency += toc - tic
                # Reset the batch buffers.
                query_lst = JArrayList()
                qid_lst = JArrayList()
                qids = []
                count = 0
        if self.sparse_vecs is not None:
            self.pool.close()
        print(f"Average search latency {latency/len(iterator)*1000:.2f}ms/query")
        return ranking

    def fast_rerank(self, qids, all_scores, all_docids):
        all_q_embeds = []
        all_d_embeds = []
        all_d_lens = []
        for qid, scores, docids in zip(qids, all_scores, all_docids):
            all_q_embeds.append(self.sparse_q_vecs[qid])
            d_embeds = []
            d_lens = []
            for docid in docids:
                start, end = self.sparse_ranges[int(docid)]
                d_embeds.append(self.sparse_vecs[int(docid)])
                d_lens.append(end - start)
            all_d_embeds.append(d_embeds)
            all_d_lens.append(d_lens)
        entries = list(zip(all_q_embeds, all_d_embeds, all_d_lens, qids, all_scores, all_docids))
        results = self.pool.map(maxsim, entries)
        qids, all_scores, all_docids = list(zip(*results))
        return qids, all_scores, all_docids


def main(args):
    searcher = LuceneMultiTermSearcher(args.query_path,
                                       args.query_embedding_path,
                                       args.sparse_query_vec_path,
                                       args.sparse_corpus_vec_path,
                                       args.corpus_path,
                                       args.weight_threshold,
                                       id2idx_path=args.id2idx_path,
                                       threads=args.threads,
                                       index_dir=args.index,
                                       query_encoder=None,
                                       min_idf=args.min_idf)
    ranking = searcher.batch_search(args.topk, args.threads, args.batch_size)
    # Optional mapping from internal integer doc indices back to external doc ids.
    i2d = []
    if args.idx2id_path is not None and os.path.exists(args.idx2id_path):
        with open(args.idx2id_path) as f:
            lines = f.readlines()
        for line in lines:
            i2d.append(line.strip())
    trec_results = []
    for topic_id, (top_scores, top_indices) in ranking.items():
        for rank, (score, doc_id) in enumerate(list(zip(top_scores, top_indices))[:args.out_topk]):
            if len(i2d) == 0:
                trec_results.append(f"{topic_id} Q0 {doc_id} {rank+1} {score:.6f} Anserini\n")
            else:
                trec_results.append(f"{topic_id} Q0 {i2d[doc_id]} {rank+1} {score:.6f} Anserini\n")
    print(f"Writing output to {args.output_path}")
    os.makedirs(args.output_path, exist_ok=True)
    with open(os.path.join(args.output_path, "retrieval.trec"), "w") as g:
        g.writelines(trec_results)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Sparse retrieval over a Lucene impact index with optional MaxSim reranking; writes a TREC run file.')
    parser.add_argument('--index', required=True, help='Path to the Lucene impact index.')
    parser.add_argument('--query_path', required=True, help='Queries file (.tsv or .jsonl).')
    parser.add_argument('--id2idx_path', default=None, help='Pickle mapping query ids to row indices of the sparse query vectors.')
    parser.add_argument('--query_embedding_path', required=True, help='JSONL file of per-query term weights ("id" and "vector" fields).')
    parser.add_argument('--sparse_query_vec_path', default=None, help='Directory with sparse query vectors (sparse_range.pkl, sparse_vec.npz) for reranking.')
    parser.add_argument('--sparse_corpus_vec_path', default=None, help='Directory with sparse corpus vectors (sparse_range.pkl, sparse_vec.npz) for reranking.')
    parser.add_argument('--idx2id_path', default=None, help='Text file mapping internal doc indices to external doc ids, one id per line.')
    parser.add_argument('--corpus_path', required=True, help='Corpus tsv file (with header); used to count passages.')
    parser.add_argument('--output_path', required=True, help='Output directory for the TREC run file.')
    parser.add_argument('--topk', type=int, default=1000, help='Number of hits to retrieve per query.')
    parser.add_argument('--out_topk', type=int, default=1000, help='Number of hits to write per query.')
    parser.add_argument('--threads', type=int, default=1, help='Number of search threads.')
    parser.add_argument('--batch_size', type=int, default=128, help='Number of queries per search batch.')
    parser.add_argument('--weight_threshold', type=float, default=0.0, help='Minimum term weight kept in queries.')
    parser.add_argument('--min_idf', type=float, default=0.0, help='Minimum IDF for a query term to be kept.')
    args = parser.parse_args()
    main(args)