refactoring_2.py
# Written against the module-level API of the pre-3.0 pinecone-client;
# pinecone.init(api_key=..., environment=...) must have been called before this class is used.
# Embedding_Model, Sparse_Embedding_Model, and logger are assumed to be defined in the
# companion modules of this refactoring series.
from typing import Dict

import pinecone
from pinecone import Index
from tqdm.auto import tqdm


class Retreiver(Index):
    def __init__(self, index_name, embedd_openai_model, sparse_model_file_name,
                 dimension=1536, metric="dotproduct"):
        self.index_name = index_name
        # Create the index if it does not exist yet
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(name=index_name,
                                  dimension=dimension,
                                  metric=metric)
        self.index = pinecone.Index(index_name)
        # Create the dense embedding model (optional)
        self.embedd_model = None
        if embedd_openai_model:
            self.embedd_model = Embedding_Model(embedd_openai_model)
        # Create the sparse embedding model (optional)
        self.sparsed_model = None
        if sparse_model_file_name:
            self.sparsed_model = Sparse_Embedding_Model(sparse_model_file_name)

    def delete_index(self, index_name):
        pinecone.delete_index(index_name)
        logger.logger.info(pinecone.list_indexes())

    def reset_index_namespace(self, namespace: str):
        self.index.delete(delete_all=True, namespace=namespace)

    def upsert_batch(self, data, batch_size=10, namespace=None):
        if self.embedd_model is None:
            raise ValueError("A dense embedding model is required to upsert vectors")
        for i in tqdm(range(0, len(data), batch_size)):
            # find end of batch
            i_end = min(i + batch_size, len(data))
            # extract batch
            batch = data[i:i_end]
            # generate dense embeddings for the batch
            values = self.embedd_model.encode([b['content'] for b in batch])
            # generate sparse embeddings for the batch, if a sparse model is configured
            if self.sparsed_model:
                sparse_values = self.sparsed_model.encode([b['content'] for b in batch])
            else:
                sparse_values = [None] * (i_end - i)
            # get metadata
            metas = [b['metadata'] for b in batch]
            # create unique IDs
            ids = [str(b['metadata']['id']) for b in batch]
            # build the upsert payload, attaching sparse values only when available
            to_upsert = []
            for _id, vec, meta, sparse in zip(ids, values, metas, sparse_values):
                record = {'id': _id, 'values': vec, 'metadata': meta}
                if sparse is not None:
                    record['sparse_values'] = sparse
                to_upsert.append(record)
            # upsert/insert these records to pinecone
            self.index.upsert(vectors=to_upsert, namespace=namespace)
        logger.logger.info(str(self.index.describe_index_stats()))
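
    # Illustrative only: upsert_batch assumes `data` is a list of records shaped roughly like
    #   {'content': 'text to embed', 'metadata': {'id': 42, ...}}
    # where 'metadata' must at least carry an 'id' used as the Pinecone vector ID.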

    @staticmethod
    def hybrid_score_norm(dense, sparse, alpha: float):
        # alpha -> 0: keyword (sparse) search dominates; alpha -> 1: dense search dominates
        if alpha < 0 or alpha > 1:
            raise ValueError("Alpha must be between 0 and 1")
        hs = {
            'indices': sparse['indices'],
            'values': [v * (1 - alpha) for v in sparse['values']]
        }
        return [v * alpha for v in dense], hs
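
    # Illustrative only: with alpha = 0.25, a dense value of 0.8 is scaled to 0.8 * 0.25 = 0.2,
    # while a sparse value of 0.6 is scaled to 0.6 * (1 - 0.25) = 0.45, so the sparse (keyword)
    # signal contributes more to the final dot-product score.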

    def retreive_with_query(self, query, top_k=3, alpha=None, namespace=None, filter_: Dict = None):
        query_embedd = self.embedd_model.encode(query)
        query_sparse = self.sparsed_model.encode(query) if self.sparsed_model else None
        if alpha is not None and query_sparse is not None:
            query_embedd, query_sparse = self.hybrid_score_norm(query_embedd, query_sparse, alpha)
        # get relevant contexts
        query_response = self.index.query(
            namespace=namespace,
            top_k=top_k,
            vector=query_embedd,
            sparse_vector=query_sparse,
            filter=filter_
        )
        # return the ids of the matched contexts
        results = [int(x['id']) for x in query_response['matches']]
        return results
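
# A minimal usage sketch, not part of the original gist. It assumes pinecone.init has already
# been called with your API key and environment, that the Embedding_Model and
# Sparse_Embedding_Model wrappers come from the earlier snippets in this series, and that
# "demo-index" and "bm25_params.json" are placeholder names.
if __name__ == "__main__":
    retriever = Retreiver(index_name="demo-index",
                          embedd_openai_model="text-embedding-ada-002",
                          sparse_model_file_name="bm25_params.json")
    docs = [
        {'content': 'Pinecone supports hybrid dense/sparse search.', 'metadata': {'id': 1}},
        {'content': 'Alpha close to 0 favours keyword matches.', 'metadata': {'id': 2}},
    ]
    retriever.upsert_batch(docs, batch_size=2, namespace="demo")
    ids = retriever.retreive_with_query("hybrid search", top_k=2, alpha=0.3, namespace="demo")
    print(ids)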