Skip to content

Instantly share code, notes, and snippets.

@ljnmedium
Created September 29, 2023 07:32
Show Gist options
  • Save ljnmedium/ca7f0d916fd1e1c31c1476d80216038c to your computer and use it in GitHub Desktop.
Save ljnmedium/ca7f0d916fd1e1c31c1476d80216038c to your computer and use it in GitHub Desktop.
refactoring_1.py
from typing import List

from openai import Embedding
from pinecone_text.sparse import BM25Encoder
# Name of a dense embedding model. Not referenced by the classes visible in
# this file; presumably passed as `model_name` to Embedding_Model by callers
# (TODO confirm).
EMBEDDING_MODEL = "text-embedding-ada-002"
# File name for stored BM25 parameters. Not referenced by the classes visible
# in this file; presumably passed as `model_file_name` to
# Sparse_Embedding_Model by callers (TODO confirm).
SPARSE_MODEL_FILE_NAME = "bm25_values.json"
class Embedding_Model(Embedding):
    """Dense text encoder built on the ``Embedding`` resource imported above.

    The model name given at construction is stored in ``self.engine`` and
    forwarded as ``model=`` on every ``Embedding.create`` call.
    """

    def __init__(self, model_name):
        """Remember which embedding model to request.

        Args:
            model_name: Model identifier forwarded to ``Embedding.create``.
        """
        # NOTE(review): the base-class initializer is deliberately not called
        # here, matching the original behavior — confirm this is intended.
        self.engine = model_name

    def encode(self, text):
        """Embed a single string or a list of strings.

        Newlines in the input are replaced by spaces before the request.

        Args:
            text: Either one string or a ``list`` of strings.

        Returns:
            For a list input, a list with one embedding per input string
            (an empty list yields ``[]`` without issuing a request).
            For a string input, the single embedding.
        """
        # Builtin `list` is used for the type check: it needs no import and
        # is equivalent to checking against the unsubscripted `typing.List`.
        if isinstance(text, list):
            if not text:
                # Nothing to embed; avoid sending an empty request.
                return []
            cleaned = [item.replace("\n", " ") for item in text]
            response = Embedding.create(input=cleaned, model=self.engine)
            return [record['embedding'] for record in response['data']]

        cleaned = text.replace("\n", " ")
        response = Embedding.create(input=[cleaned], model=self.engine)
        return response['data'][0]['embedding']
class Sparse_Embedding_Model(BM25Encoder):
    """Sparse text encoder that delegates to a ``BM25Encoder`` held in ``self.model``.

    ``self.model_name`` records the parameter file the wrapped encoder was
    last loaded from or dumped to, or ``None`` when the default encoder is
    in use.
    """

    def __init__(self, model_file_name=None):
        """Load an encoder from a file, or fall back to the default one.

        Args:
            model_file_name: File to load encoder parameters from. When
                falsy (``None`` or ``""``), ``BM25Encoder.default()`` is
                used instead.
        """
        # NOTE(review): the base-class initializer is not called; all work is
        # delegated to the wrapped `self.model` — confirm this is intended.
        if model_file_name:
            self.model_name = model_file_name
            self.model = BM25Encoder().load(model_file_name)
        else:
            # Always define the attribute so later reads cannot raise
            # AttributeError when no file was given.
            self.model_name = None
            self.model = BM25Encoder.default()

    def finetune(self, texts: List[str], model_file_name):
        """Fit the wrapped encoder on ``texts`` and dump it to a file.

        Args:
            texts: Corpus passed to the encoder's ``fit``.
            model_file_name: Destination passed to ``dump`` on the value
                returned by ``fit``.
        """
        self.model.fit(texts).dump(model_file_name)
        # The parameters now live in this file; keep the record in sync.
        self.model_name = model_file_name

    def load_model(self, model_file_name):
        """Replace the wrapped encoder with one loaded from ``model_file_name``."""
        self.model = BM25Encoder().load(model_file_name)
        self.model_name = model_file_name

    def encode(self, text):
        """Return ``self.model.encode_documents(text)``."""
        return self.model.encode_documents(text)

    def query_encode(self, query: str):
        """Return ``self.model.encode_queries(query)``."""
        return self.model.encode_queries(query)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment