Created
September 29, 2023 07:36
-
-
Save ljnmedium/385b2e0cfbe97e5a6836a31cd9c7edd4 to your computer and use it in GitHub Desktop.
chat_doc.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Default completion model. NOTE(review): 'text-davinci-003' is a legacy
# OpenAI Completions model — confirm it is still served before deploying.
OPENAI_MODEL = 'text-davinci-003'
class ChatDoc():
    """Thin wrapper around the OpenAI Completion API for answering free-form
    prompts and retrieval-augmented questions (question + retrieved contexts).

    Engine parameters default to deterministic output (temperature=0) and can
    be overridden per-instance via ``**model_params``.
    """

    # Separator placed between retrieved context passages in the prompt.
    _CONTEXT_SEPARATOR = "\n\n---\n\n"

    def __init__(self, model_name: Optional[str] = None, **model_params):
        """Store the model name and build the completion-call parameters.

        Args:
            model_name: OpenAI model identifier. Defaults to the module-level
                ``OPENAI_MODEL`` constant (resolved lazily at call time so the
                class can be imported without the constant in scope).
            **model_params: extra keyword overrides merged into the default
                engine parameters (e.g. ``max_tokens=800``).
        """
        if model_name is None:
            model_name = OPENAI_MODEL
        self.model = model_name
        self.engine_params = dict(
            model=model_name,
            temperature=0,          # deterministic completions
            max_tokens=400,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
        )
        if model_params:
            self.engine_params.update(model_params)

    @staticmethod
    def connect_openai_api():
        """Load the API key from the environment (.env) and verify access.

        Returns:
            bool: True when OPENAI_API_KEY is set and the model list call
            succeeds (makes a network request), False otherwise.
        """
        load_dotenv()
        if 'OPENAI_API_KEY' in os.environ:
            openai.api_key = os.getenv('OPENAI_API_KEY')
            if openai.Model.list():
                return True
        return False

    @staticmethod
    def _build_context_prompt(query: str, contexts: List[str], limit_context: int) -> str:
        """Assemble the QA prompt from as many contexts as fit the limit.

        Joins contexts with ``_CONTEXT_SEPARATOR`` and keeps the largest
        prefix whose joined length stays below ``limit_context``. If even the
        first context exceeds the limit, that single context is kept anyway
        (original intent per the inline comment; the old code instead fell
        through and joined *all* contexts in that case — bug fixed here).
        """
        sep = ChatDoc._CONTEXT_SEPARATOR
        prompt_start = (
            "Answer the question based on the context below.\n\n"
            "Context:\n"
        )
        prompt_end = f"\n\nQuestion: {query}\nAnswer:"
        contexts_query = ""
        # Test every prefix, INCLUDING the full list (the old loop stopped at
        # len(contexts)-1, so an over-limit full join was never detected).
        for count in range(1, len(contexts) + 1):
            candidate = sep.join(contexts[:count])
            if len(candidate) >= limit_context:
                # Largest prefix that fits; never fewer than one context.
                contexts_query = sep.join(contexts[:max(count - 1, 1)])
                break
        else:
            contexts_query = sep.join(contexts)
        return prompt_start + contexts_query + prompt_end

    def prompt_completion(self, prompt: str, verbose=False):
        """Send a raw prompt to the completion endpoint and return the text."""
        if verbose:
            logger.logger.info(prompt)
        return self.complete(prompt)

    def prompt_with_context(self, query: str, contexts: List[str], limit_context: int = 5000, verbose=False):
        """Answer ``query`` using retrieved ``contexts``, truncated to fit.

        Args:
            query: the user question.
            contexts: retrieved passages, most relevant first.
            limit_context: max character length of the joined context section.
            verbose: log the assembled prompt before sending.
        """
        prompt = self._build_context_prompt(query, contexts, limit_context)
        if verbose:
            logger.logger.info(prompt)
        return self.complete(prompt)

    def complete(self, prompt: str):
        """Call the OpenAI Completion API and return the stripped first choice."""
        res = openai.Completion.create(
            prompt=prompt, **self.engine_params
        )
        return res['choices'][0]['text'].strip()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment