Skip to content

Instantly share code, notes, and snippets.

@ljnmedium
Created September 29, 2023 07:36
Show Gist options
Save ljnmedium/385b2e0cfbe97e5a6836a31cd9c7edd4 to your computer and use it in GitHub Desktop.
chat_doc.py
# Default model used when ChatDoc is constructed without an explicit model name.
# NOTE(review): text-davinci-003 is a legacy completions model that OpenAI has
# retired; confirm a supported model before relying on this default.
OPENAI_MODEL = 'text-davinci-003'


class ChatDoc():
    """Wrapper around ``openai.Completion.create`` for (context-grounded) Q&A.

    Holds a fixed set of completion parameters and offers helpers to send a
    raw prompt or a prompt built from a question plus retrieved context
    passages.

    NOTE(review): ``openai.Completion`` / ``openai.Model`` are the legacy
    (pre-1.0) openai SDK interface; verify against the installed SDK version.
    """

    # Separator placed between context passages inside the prompt.
    CONTEXT_SEPARATOR = "\n\n---\n\n"

    def __init__(self, model_name: str = OPENAI_MODEL, **model_params):
        """Configure the model and default completion parameters.

        Args:
            model_name: Model name sent as ``model`` to the completions API.
            **model_params: Extra or overriding keyword arguments for
                ``openai.Completion.create`` (e.g. ``temperature``,
                ``max_tokens``); they replace the defaults set here.
        """
        self.model = model_name
        # Defaults: temperature 0, up to 400 tokens, no penalties, no stop
        # sequence. Anything in model_params overrides these.
        self.engine_params = dict(
            model=model_name,
            temperature=0,
            max_tokens=400,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
        )
        if model_params:
            self.engine_params.update(model_params)

    @staticmethod
    def connect_openai_api() -> bool:
        """Load ``.env`` and configure the OpenAI API key.

        Returns:
            ``False`` if ``OPENAI_API_KEY`` is not set in the environment.
            Otherwise sets ``openai.api_key`` and returns whether
            ``openai.Model.list()`` produced a truthy result (used here as a
            connectivity check).

        Errors raised by ``openai.Model.list()`` are not caught and propagate
        to the caller.
        """
        load_dotenv()
        api_key = os.environ.get('OPENAI_API_KEY')
        if api_key is None:
            return False
        openai.api_key = api_key
        return bool(openai.Model.list())

    def prompt_completion(self, prompt: str, verbose=False):
        """Send *prompt* unchanged and return the stripped completion text.

        If *verbose* is true, the prompt is logged via ``logger.logger.info``
        before the request is made.
        """
        if verbose:
            logger.logger.info(prompt)
        return self.complete(prompt)

    def prompt_with_context(self, query: str, contexts: List[str],
                            limit_context: int = 5000, verbose=False):
        """Answer *query* grounded in the retrieved *contexts*.

        Contexts are packed in the given order, joined by
        ``CONTEXT_SEPARATOR``, for as long as the joined text stays shorter
        than *limit_context* characters. The first context is always kept,
        even if it alone exceeds the limit.

        Args:
            query: Question to answer.
            contexts: Candidate passages; earlier ones take priority
                (presumably the retriever's ranking order -- confirm with
                callers).
            limit_context: Character budget for the joined context section.
            verbose: If true, log the final prompt before sending it.

        Returns:
            The stripped completion text from ``complete``.
        """
        prompt = (
            "Answer the question based on the context below.\n\n"
            "Context:\n"
            + self._pack_contexts(contexts, limit_context)
            + f"\n\nQuestion: {query}\nAnswer:"
        )
        return self.prompt_completion(prompt, verbose=verbose)

    @classmethod
    def _pack_contexts(cls, contexts: List[str], limit_context: int) -> str:
        """Join the longest prefix of *contexts* whose joined length < limit.

        Returns "" when *contexts* is empty. The first context is always
        kept, so a non-empty list never yields an empty context section.
        """
        sep_len = len(cls.CONTEXT_SEPARATOR)
        kept = 0
        # Running length of SEPARATOR.join(contexts[:kept + 1]); starts at
        # -sep_len because the first passage is not preceded by a separator.
        joined_len = -sep_len
        for passage in contexts:
            joined_len += sep_len + len(passage)
            if joined_len >= limit_context:
                break
            kept += 1
        # Oversized first passage: keep it alone rather than send no context.
        kept = max(kept, min(1, len(contexts)))
        return cls.CONTEXT_SEPARATOR.join(contexts[:kept])

    def complete(self, prompt: str):
        """Run a completion for *prompt* with ``engine_params``.

        Returns the text of the first returned choice with surrounding
        whitespace stripped.
        """
        res = openai.Completion.create(prompt=prompt, **self.engine_params)
        return res['choices'][0]['text'].strip()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment