Created
September 29, 2023 07:32
-
-
Save ljnmedium/ca7f0d916fd1e1c31c1476d80216038c to your computer and use it in GitHub Desktop.
refactoring_1.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List

from openai import Embedding
from pinecone_text.sparse import BM25Encoder
# Model identifier for dense embeddings; presumably the value handed to
# Embedding_Model(model_name) by callers -- confirm at the call site.
EMBEDDING_MODEL = "text-embedding-ada-002"
# File name for persisted BM25 parameters; presumably the value handed to
# Sparse_Embedding_Model / finetune / load_model -- confirm at the call site.
SPARSE_MODEL_FILE_NAME = "bm25_values.json"
class Embedding_Model(Embedding):
    """Dense text encoder built on the OpenAI ``Embedding`` resource.

    The instance only stores the model name; every call to :meth:`encode`
    forwards it to ``Embedding.create``.
    """

    def __init__(self, model_name):
        """Remember which embedding model to request.

        Args:
            model_name: Model identifier sent as ``model=`` to
                ``Embedding.create``.
        """
        # NOTE(review): the parent initializer is deliberately left uncalled,
        # exactly as in the original code -- confirm this is safe for the
        # installed openai version.
        self.engine = model_name

    def encode(self, text):
        """Embed a single string or a list of strings.

        Newlines are replaced with spaces before the text is sent.

        Args:
            text: Either one ``str`` or a ``list`` of ``str``.

        Returns:
            For a list input, a list with one embedding per input string
            (taken from each ``data[i]['embedding']`` of the response).
            For a single string, just that string's embedding.
        """
        # The builtin ``list`` is used for the type check: the original tested
        # against ``List``, a name that was never imported.
        if isinstance(text, list):
            cleaned = [item.replace("\n", " ") for item in text]
            response = Embedding.create(input=cleaned, model=self.engine)
            return [record["embedding"] for record in response["data"]]

        cleaned = text.replace("\n", " ")
        response = Embedding.create(input=[cleaned], model=self.engine)
        return response["data"][0]["embedding"]
class Sparse_Embedding_Model(BM25Encoder):
    """Sparse (BM25) text encoder that wraps a ``BM25Encoder`` instance.

    All encoding is delegated to ``self.model``. ``self.model_name`` holds the
    file name the current parameters were last loaded from or written to, or
    the falsy value given at construction when the library default is used.
    """

    def __init__(self, model_file_name=None):
        """Load encoder parameters from a file, or fall back to the default.

        Args:
            model_file_name: File to load BM25 parameters from. When falsy
                (``None`` or ``""``), ``BM25Encoder.default()`` is used.
        """
        # Always define the attribute so later reads never raise
        # AttributeError, whichever branch was taken.
        self.model_name = model_file_name
        if model_file_name:
            self.model = BM25Encoder().load(model_file_name)
        else:
            self.model = BM25Encoder.default()

    def finetune(self, texts: List[str], model_file_name):
        """Fit the wrapped encoder on ``texts`` and write it to a file.

        Args:
            texts: Corpus passed to ``fit``.
            model_file_name: Destination passed to ``dump``.
        """
        self.model.fit(texts).dump(model_file_name)
        # The in-memory parameters now match what was written to this file.
        self.model_name = model_file_name

    def load_model(self, model_file_name):
        """Replace the wrapped encoder with one loaded from a file.

        Args:
            model_file_name: File to load BM25 parameters from.
        """
        self.model = BM25Encoder().load(model_file_name)
        self.model_name = model_file_name

    def encode(self, text):
        """Encode document text with the wrapped encoder.

        Args:
            text: Input forwarded unchanged to ``encode_documents``.

        Returns:
            Whatever ``encode_documents`` returns for ``text``.
        """
        return self.model.encode_documents(text)

    def query_encode(self, query: str):
        """Encode a search query with the wrapped encoder.

        Args:
            query: Query text forwarded unchanged to ``encode_queries``.

        Returns:
            Whatever ``encode_queries`` returns for ``query``.
        """
        return self.model.encode_queries(query)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment