Skip to content

Instantly share code, notes, and snippets.

@ljnmedium
Created September 29, 2023 07:34
Show Gist options
  • Save ljnmedium/dc44950c7cc96e55aace72d47882b5d6 to your computer and use it in GitHub Desktop.
Save ljnmedium/dc44950c7cc96e55aace72d47882b5d6 to your computer and use it in GitHub Desktop.
refactoring_2.py
from typing import Dict

import pinecone
from pinecone import Index
class Retreiver(Index):
    """Hybrid (dense + sparse) document retriever backed by a Pinecone index.

    Wraps index creation, batch upserts of dense/sparse embeddings, and
    hybrid-weighted querying. Dense vectors come from ``Embedding_Model``
    (OpenAI), sparse vectors from ``Sparse_Embedding_Model``.
    """

    def __init__(self, index_name, embedd_openai_model, sparse_model_file_name,
                 dimension=1536, metric="dotproduct"):
        """Create (if needed) and connect to the Pinecone index.

        Args:
            index_name: Name of the Pinecone index to use or create.
            embedd_openai_model: Dense embedding model identifier; falsy to skip.
            sparse_model_file_name: Sparse model file name; falsy to skip.
            dimension: Dense vector dimension (1536 matches OpenAI ada-002).
            metric: Similarity metric; "dotproduct" is required for hybrid search.
        """
        self.index_name = index_name
        # Create the index only if it does not already exist.
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(name=index_name,
                                  dimension=dimension,
                                  metric=metric)
        self.index = pinecone.Index(index_name)
        # BUGFIX: always define both model attributes. The original only set
        # them inside the truthiness checks, so a missing model produced an
        # AttributeError later in upsert_batch / retreive_with_query instead
        # of a clean "no model" state.
        self.embedd_model = (Embedding_Model(embedd_openai_model)
                             if embedd_openai_model else None)
        self.sparsed_model = (Sparse_Embedding_Model(sparse_model_file_name)
                              if sparse_model_file_name else None)

    def delete_index(self, index_name):
        """Delete *index_name* from Pinecone and log the remaining indexes."""
        pinecone.delete_index(index_name)
        logger.logger.info(pinecone.list_indexes())

    def reset_index_namespace(self, namespace: str):
        """Delete every vector stored under *namespace* in this index."""
        self.index.delete(delete_all=True, namespace=namespace)

    def upsert_batch(self, data, batch_size=10, namespace=None):
        """Embed and upsert *data* into the index in batches.

        Args:
            data: Iterable of dicts with 'content' (text) and 'metadata'
                (must contain an 'id' key) entries.
            batch_size: Number of records embedded and upserted per call.
            namespace: Optional Pinecone namespace for the vectors.

        Raises:
            ValueError: If no dense embedding model was configured.
        """
        # BUGFIX: the original fell through to a NameError on `values` when
        # no dense model was configured; fail fast with a clear message.
        if not self.embedd_model:
            raise ValueError("A dense embedding model is required for upserts")
        for start in tqdm(range(0, len(data), batch_size)):
            # Batch bounds (the final batch may be shorter).
            end = min(start + batch_size, len(data))
            batch = data[start:end]
            contents = [b['content'] for b in batch]
            # Dense embeddings for the batch.
            values = self.embedd_model.encode(contents)
            # Sparse embeddings, or per-record None placeholders when no
            # sparse model is configured (preserves the original payload shape).
            if self.sparsed_model:
                sparse_values = self.sparsed_model.encode(contents)
            else:
                sparse_values = [None] * (end - start)
            metas = [b['metadata'] for b in batch]
            # Unique string IDs taken from each record's metadata.
            ids = [str(b['metadata']['id']) for b in batch]
            # NOTE: zip variable renamed (was `i`) so it no longer shadows
            # the batch loop index.
            to_upsert = [
                {'id': vid, 'values': v, 'metadata': m, 'sparse_values': sv}
                for (vid, v, m, sv) in zip(ids, values, metas, sparse_values)
            ]
            self.index.upsert(vectors=to_upsert, namespace=namespace)
        logger.logger.info(str(self.index.describe_index_stats()))

    @staticmethod
    def hybrid_score_norm(dense, sparse, alpha: float):
        """Weight dense vs. sparse scores for hybrid search.

        Scales the dense vector by *alpha* and the sparse values by
        ``1 - alpha``; alpha -> 0 favours keyword (sparse) search,
        alpha -> 1 favours semantic (dense) search.

        Args:
            dense: List of dense vector values.
            sparse: Dict with 'indices' and 'values' keys.
            alpha: Weight in [0, 1].

        Returns:
            Tuple of (scaled dense list, scaled sparse dict).

        Raises:
            ValueError: If alpha is outside [0, 1].
        """
        if alpha < 0 or alpha > 1:
            raise ValueError("Alpha must be between 0 and 1")
        hs = {
            'indices': sparse['indices'],
            'values': [v * (1 - alpha) for v in sparse['values']]
        }
        return [v * alpha for v in dense], hs

    def retreive_with_query(self, query, top_k=3, alpha=None, namespace=None,
                            filter_: Dict = None):
        """Query the index and return the integer ids of the best matches.

        Args:
            query: Query text to embed (both dense and sparse).
            top_k: Number of matches to return.
            alpha: Optional hybrid weight in [0, 1]; None skips re-weighting.
            namespace: Optional Pinecone namespace to search.
            filter_: Optional Pinecone metadata filter dict.

        Returns:
            List of matched context ids as ints.
        """
        query_embedd = self.embedd_model.encode(query)
        query_sparse = self.sparsed_model.encode(query)
        # BUGFIX: `is not None` instead of truthiness — alpha == 0.0 is a
        # valid weighting (pure keyword search) and must still normalise.
        if alpha is not None:
            query_embedd, query_sparse = self.hybrid_score_norm(
                query_embedd, query_sparse, alpha)
        # Fetch the most relevant contexts.
        query_response = self.index.query(
            namespace=namespace,
            top_k=top_k,
            vector=query_embedd,
            sparse_vector=query_sparse,
            filter=filter_
        )
        # Return only the ids of the matched contexts.
        return [int(x['id']) for x in query_response['matches']]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment