# base_class.py
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Optional

import tenacity


class Model(metaclass=ABCMeta):
    name = ""
    description = ""

    def __init__(
        self,
        api_key: str,
        model: str,
        api_wait: Optional[int] = None,
        api_retry: Optional[int] = None,
        **kwargs,
    ):
        """
        Initializes the Model class with the required parameters and verifies that the model is supported by the endpoint.

        :param api_key: str, API key for the endpoint, if one is required
        :param model: str, name of the LLM model to use for the endpoint
        :param api_wait: int, upper bound on the exponential backoff wait between retries (in seconds)
        :param api_retry: int, maximum number of attempts for an API request before giving up
        :param **kwargs: additional endpoint-specific keyword arguments
        """
        self.api_key = api_key
        self.model = model
        self.api_wait = api_wait
        self.api_retry = api_retry
        self._verify_model()
        self.set_key(api_key)

    @classmethod
    @abstractmethod
    def supported_models(cls) -> Dict[str, str]:
        """
        Get a mapping of supported model names to their descriptions for the endpoint
        """
        raise NotImplementedError

    @abstractmethod
    def _verify_model(self):
        """
        Verify the model is supported by the endpoint
        """
        raise NotImplementedError

    @abstractmethod
    def set_key(self, api_key: str):
        """
        Set endpoint API key if needed
        """
        raise NotImplementedError

    @abstractmethod
    def set_model(self, model: str):
        """
        Set model name for the endpoint
        """
        raise NotImplementedError

    @abstractmethod
    def get_description(self) -> str:
        """
        Get model description
        """
        raise NotImplementedError

    @abstractmethod
    def get_endpoint(self) -> str:
        """
        Get model endpoint
        """
        raise NotImplementedError

    @abstractmethod
    def get_parameters(self) -> Dict[str, str]:
        """
        Get model parameters
        """
        raise NotImplementedError

    @abstractmethod
    def run(self, prompts: List[str]) -> List[str]:
        """
        Run the LLM on the given prompt list.

        :param prompts: List[str], list of prompts to run on the LLM
        :returns: List[str], list of responses from the LLM
        """
        raise NotImplementedError

    @abstractmethod
    def model_output(self, response):
        """
        Get the model output from the response
        """
        raise NotImplementedError

    def _retry_decorator(self):
        """
        Decorator function for retrying API requests if they fail
        """
        return tenacity.retry(
            wait=tenacity.wait_random_exponential(
                multiplier=0.3, exp_base=3, max=self.api_wait
            ),
            stop=tenacity.stop_after_attempt(self.api_retry),
        )

    def execute_with_retry(self, *args, **kwargs):
        """
        Decorated version of the `run` method with the retry logic
        """
        decorated_run = self._retry_decorator()(self.run)
        return decorated_run(*args, **kwargs)
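
For illustration, a minimal sketch of a concrete subclass, assuming only the Model base class above; the EchoModel name and its echo behaviour are invented for the example and are not part of the gist.

from typing import Dict, List

from base_class import Model


class EchoModel(Model):
    """Hypothetical endpoint that echoes prompts back; shows the required overrides."""

    name = "Echo"
    description = "Toy endpoint used to illustrate the Model interface"

    @classmethod
    def supported_models(cls) -> Dict[str, str]:
        return {"echo-1": "Returns the prompt unchanged"}

    def _verify_model(self):
        if self.model not in self.supported_models():
            raise ValueError(f"Unsupported model: {self.model}")

    def set_key(self, api_key: str):
        pass  # no key needed for a local toy endpoint

    def set_model(self, model: str):
        self.model = model
        self._verify_model()

    def get_description(self) -> str:
        return self.supported_models()[self.model]

    def get_endpoint(self) -> str:
        return "local"

    def get_parameters(self) -> Dict[str, str]:
        return {}

    def run(self, prompts: List[str]) -> List[str]:
        return prompts  # echo each prompt back as its own "completion"

    def model_output(self, response):
        return response


# EchoModel("", "echo-1", api_wait=5, api_retry=2).execute_with_retry(["hi"]) -> ["hi"]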

# OpenAI completion endpoint (separate file in the gist; imports base_class.py and parser.py)
from typing import Dict, List, Optional, Tuple, Union

import openai
import tiktoken

from parser import Parser
from base_class import Model


class OpenAI_Complete(Model):
    name = "OpenAI"
    description = "OpenAI API for text completion using various models"

    def __init__(
        self,
        api_key: str,
        model: str = "text-davinci-003",
        temperature: float = 0.7,
        top_p: float = 1,
        n: int = 1,
        logprobs: Optional[int] = None,
        echo: bool = False,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: float = 0,
        frequency_penalty: float = 0,
        best_of: int = 1,
        logit_bias: Optional[Dict[str, int]] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
        api_wait: Optional[int] = None,
        api_retry: Optional[int] = None,
        max_completion_length: int = 20,
    ):
        super().__init__(api_key, model, api_wait, api_retry)
        self.temperature = temperature
        self.top_p = top_p
        self.n = n
        self.logprobs = logprobs
        self.echo = echo
        self.stop = stop
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.best_of = best_of
        self.logit_bias = logit_bias or {}
        self.request_timeout = request_timeout
        self.max_completion_length = max_completion_length
        self._verify_model()
        self.encoder = tiktoken.encoding_for_model(self.model)
        self.max_tokens = self.default_max_tokens(self.model)
        self.parser = Parser()
        self.set_key(self.api_key)

    @classmethod
    def supported_models(cls) -> Dict[str, str]:
        return {
            "text-davinci-003": "text-davinci-003 can do any language task with better quality, longer output, and consistent instruction-following than the curie, babbage, or ada models. Also supports inserting completions within text.",
            "text-curie-001": "text-curie-001 is very capable, faster and lower cost than Davinci.",
            "text-babbage-001": "text-babbage-001 is capable of straightforward tasks, very fast, and lower cost.",
            "text-ada-001": "text-ada-001 is capable of very simple tasks, usually the fastest model in the GPT-3 series, and lowest cost.",
        }

    def default_max_tokens(self, model_name: str) -> int:
        token_dict = {
            "text-davinci-003": 4000,
            "text-curie-001": 2048,
            "text-babbage-001": 2048,
            "text-ada-001": 2048,
        }
        return token_dict[model_name]

    def _verify_model(self):
        if self.model not in self.supported_models():
            raise ValueError(f"Unsupported model: {self.model}")

    def set_key(self, api_key: str):
        self._openai = openai
        self._openai.api_key = api_key

    def set_model(self, model: str):
        self.model = model
        self._verify_model()

    def get_description(self) -> str:
        return self.supported_models()[self.model]

    def get_endpoint(self) -> str:
        model = openai.Model.retrieve(self.model)
        return model["id"]

    def calculate_max_tokens(self, prompt: str) -> int:
        # Completion budget = model context window minus the tokens already used by the prompt
        prompt_tokens = len(self.encoder.encode(prompt))
        max_tokens = self.default_max_tokens(self.model) - prompt_tokens
        return max_tokens

    def model_output(self, response: Dict) -> Dict:
        data = {}
        data["text"] = response["choices"][0]["text"]
        data["usage"] = dict(response["usage"])
        return data

    def model_output_with_parser(self, response: Dict, max_completion_length: int) -> Dict:
        data = {}
        data["text"] = self.parser.escaped_(response["choices"][0]["text"])
        data["usage"] = dict(response["usage"])
        data["parsed"] = self.parser.fit(data["text"], max_completion_length)
        return data

    def get_parameters(
        self,
    ) -> Dict[str, Union[str, int, float, List[str], Dict[str, int]]]:
        return {
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": self.n,
            "logprobs": self.logprobs,
            "echo": self.echo,
            "stop": self.stop,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "best_of": self.best_of,
            "logit_bias": self.logit_bias,
            "request_timeout": self.request_timeout,
        }

    def run(self, prompts: List[str]) -> List[Dict]:
        """
        :param prompts: List[str], list of prompt strings to generate completions for
        :returns: list of raw API responses, one per prompt
        """
        result = []
        for prompt in prompts:
            # Automatically calculate the completion token budget left for this prompt
            max_tokens = self.calculate_max_tokens(prompt)
            response = self._openai.Completion.create(
                model=self.model,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                n=self.n,
                logprobs=self.logprobs,
                echo=self.echo,
                stop=self.stop,
                best_of=self.best_of,
                logit_bias=self.logit_bias,
                request_timeout=self.request_timeout,
                presence_penalty=self.presence_penalty,
                frequency_penalty=self.frequency_penalty,
            )
            result.append(response)
        return result
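
For illustration, a minimal sketch of driving the completion wrapper and unpacking one raw response with model_output; the environment-variable API key and the example prompt are assumptions, and the text/usage fields follow the dictionary built above.

# Sketch only: assumes a valid OpenAI API key is available in OPENAI_API_KEY.
import os

llm = OpenAI_Complete(
    api_key=os.environ["OPENAI_API_KEY"],
    model="text-davinci-003",
    api_wait=10,
    api_retry=3,
)

responses = llm.execute_with_retry(["Summarise: the mitochondria is the powerhouse of the cell."])
parsed = llm.model_output(responses[0])
print(parsed["text"])    # completion text from choices[0]
print(parsed["usage"])   # token usage reported by the API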

# parser.py
import itertools
from operator import itemgetter
from typing import Any, Dict, List, Optional
import re
import ast


class Parser:
    """
    A class to parse incomplete JSON objects and provide possible completions.

    Methods
    -------
    is_valid_json() -> bool:
        Checks if a string is valid JSON.
    escaped_() -> str:
        Normalizes quote characters inside a string so it can be evaluated.
    get_combinations() -> List[str]:
        Returns all possible combinations of } and ] characters up to length n.
    complete_json_object() -> Any:
        Completes a JSON object string by appending a completion string.
    get_possible_completions() -> Dict[str, Any]:
        Finds possible completions for an incomplete JSON object string.
    fit() -> Dict[str, Any]:
        Tries to parse the input JSON string and complete it if it is incomplete.
    find_max_length() -> Dict[str, List[Any]]:
        Returns a dictionary containing the element with the maximum length in the input list.
    extract_complete_objects() -> List[Any]:
        Extracts all complete Python objects embedded in a string.
    """

    def __init__(self):
        pass

    def is_valid_json(self, input_str: str) -> bool:
        """
        Check if the input string is a valid JSON-like object.

        Parameters
        ----------
        input_str : str
            The string to check for validity.

        Returns
        -------
        bool
            Returns True if the input string evaluates to a dict or list, otherwise False.

        Notes
        -----
        This function evaluates the input string with `eval()`; if the result is a
        dictionary or a list, it returns True. Otherwise (including when evaluation
        fails), it returns False. Note that this also accepts Python literals, such as
        single-quoted strings, that strict JSON would reject.

        Examples
        --------
        >>> validator = Parser()
        >>> validator.is_valid_json('{"name": "Alice", "age": 30}')
        True
        >>> validator.is_valid_json('[1, 2, 3, 4]')
        True
        >>> validator.is_valid_json('{"name": "Bob", "age": }')
        False
        >>> validator.is_valid_json('not a JSON string')
        False
        """
        try:
            output = eval(input_str)
            if isinstance(output, (dict, list)):
                return True
            else:
                return False
        except Exception:
            return False

    def escaped_(self, data: str) -> str:
        # Normalize quotes that appear between word characters (e.g. apostrophes inside
        # values) so the string can be evaluated without breaking the enclosing quotes.
        if "'" in data:
            escaped_str = re.sub(r"(?<=\w)(')(?=\w)", r"\"", data)
        else:
            escaped_str = re.sub(r'(?<=\w)(")(?=\w)', r"\'", data)
        return escaped_str

    def get_combinations(
        self, candidate_marks: List[str], n: int, should_end_mark: Optional[str] = None
    ) -> List[str]:
        """
        Return all possible combinations of candidate marks up to length n.

        Parameters
        ----------
        candidate_marks : list of str
            Candidate marks to combine.
        n : int
            The maximum length of the combinations.
        should_end_mark : str or None, optional
            If provided, only combinations that end with this mark will be returned, by default None.

        Returns
        -------
        list of str
            A list of all possible combinations of candidate marks up to length n.
        """
        combinations = []
        for i in range(1, n + 1):
            for comb in itertools.product(candidate_marks, repeat=i):
                if should_end_mark is not None and comb[-1] != should_end_mark:
                    # cut down the search space
                    continue
                combinations.append("".join(comb))
        return combinations

    def complete_json_object(self, json_str: str, completion_str: str) -> Any:
        """
        Complete a JSON object string by appending a completion string.

        Parameters
        ----------
        json_str : str
            The original JSON object string.
        completion_str : str
            The completion string to append.

        Returns
        -------
        Any
            The completed JSON object as a Python object.

        Raises
        ------
        ValueError
            If the JSON object string cannot be fixed.

        Notes
        -----
        This function appends `completion_str` to the end of `json_str` and tries to
        evaluate the result. If that fails, it removes one character from the end of
        `json_str` and tries again, until it either finds a valid object or runs out
        of characters to remove.

        Examples
        --------
        >>> Parser().complete_json_object('{"a": 1, "b": 2', '}')
        {'a': 1, 'b': 2}
        >>> Parser().complete_json_object('{"a": 1, "b": 2}}}}}}}}}}}}}}}}}}}', '')
        {'a': 1, 'b': 2}
        >>> Parser().complete_json_object('{"a": 1, "b": 2', '')
        Traceback (most recent call last):
        ...
        ValueError: Couldn't fix JSON
        """
        while True:
            if not json_str:
                raise ValueError("Couldn't fix JSON")
            try:
                complete_json_str = json_str + completion_str
                python_obj = eval(complete_json_str)
            except Exception:
                json_str = json_str[:-1]
                continue
            return python_obj

    def get_possible_completions(
        self, json_str: str, max_completion_length: int = 5
    ) -> Dict[str, Any]:
        """
        Find possible completions for an incomplete JSON object string.

        Parameters
        ----------
        json_str : str
            The JSON object string
        max_completion_length : int, optional
            The maximum length of the completion strings to try (default is 5)

        Returns
        -------
        Dict[str, Any]
            A dictionary with a 'completion' key (the longest completed object found)
            and a 'suggestions' key (all completed objects, sorted by length).
        """
        candidate_marks = ["}", "]"]
        if "[" not in json_str:
            candidate_marks.remove("]")
        if "{" not in json_str:
            candidate_marks.remove("}")
        # The completion must end with the mark that closes the outermost container.
        should_end_mark = "]" if json_str.strip()[0] == "[" else "}"
        completions = []
        for completion_str in self.get_combinations(
            candidate_marks, max_completion_length, should_end_mark=should_end_mark
        ):
            try:
                completed_obj = self.complete_json_object(json_str, completion_str)
                completions.append(completed_obj)
            except Exception:
                pass
        return self.find_max_length(completions)

    def fit(self, json_str: str, max_completion_length: int = 5) -> Dict[str, Any]:
        """
        Tries to parse the input JSON string and complete it if it is incomplete.

        Parameters
        ----------
        json_str : str
            The input JSON string
        max_completion_length : int, optional
            The maximum length of the completion strings to try (default is 5)

        Returns
        -------
        Dict[str, Any]
            A dictionary with 'status', 'object_type', and 'data' keys. If the string parses
            as-is, the status is 'completed' and 'data' contains the parsed object and an
            empty list of suggestions. If the string is incomplete but can be repaired, the
            status is also 'completed' and 'data' contains the best completion plus the
            remaining suggestions. If no completion is found, the status is 'failed' and
            'data' contains an error message string.
        """
        try:
            output = eval(json_str)
            return {
                "status": "completed",
                "object_type": type(output),
                "data": {"completion": output, "suggestions": []},
            }
        except Exception:
            # Remove trailing braces, brackets, and whitespace to speed up the search.
            json_str = re.sub(r"[\[\]\{\}\s]+$", "", json_str)
            try:
                output = self.get_possible_completions(
                    json_str, max_completion_length=max_completion_length
                )
                return {
                    "status": "completed",
                    "object_type": type(output["completion"]),
                    "data": output,
                }
            except Exception as e:
                return {
                    "status": "failed",
                    "object_type": None,
                    "data": {"error_message": str(e)},
                }

    def find_max_length(self, data_list: List[Any]) -> Dict[str, List[Any]]:
        """
        Returns a dictionary containing the element with the maximum length in the input list,
        as well as a list of all elements sorted by length in descending order.

        Parameters
        ----------
        data_list : list of any type
            A list of elements to be compared by length

        Returns
        -------
        dict
            A dictionary with keys 'completion' and 'suggestions'.
            The value of the 'completion' key is the element with the maximum length in the input list.
            The value of the 'suggestions' key is a list of all elements sorted by length in descending order.
        """
        # Map each element's index to the length of its string representation.
        length_dict = {i: len(str(element)) for i, element in enumerate(data_list)}
        # Sort the (index, length) pairs by length in descending order.
        sorted_indices = sorted(length_dict.items(), key=itemgetter(1), reverse=True)
        # The longest element becomes the 'completion'; all elements, sorted by length,
        # become the 'suggestions'.
        output_dict = {
            "completion": data_list[sorted_indices[0][0]],
            "suggestions": [data_list[i] for i, _ in sorted_indices],
        }
        return output_dict

    def extract_complete_objects(self, string: str) -> List[Any]:
        """
        Extracts all complete Python objects from a string.

        Parameters
        ----------
        string : str
            The string to extract objects from.

        Returns
        -------
        List[Any]
            A list of all complete Python objects found in the string.
        """
        object_regex = r"(?<!\\)(\[[^][]*?(?<!\\)\]|\{[^{}]*\})"
        # The regular expression pattern matches any string starting with an opening brace or bracket,
        # followed by any number of non-brace and non-bracket characters, and ending with a closing
        # brace or bracket that is not preceded by a backslash escape character.
        object_strings = []
        opening = {"{": 0, "[": 0}
        closing = {"}": "{", "]": "["}
        stack = []
        start = 0
        for match in re.finditer(object_regex, string):
            if len(stack) == 0:
                start = match.start()
            stack.append(match.group(1))
            if match.group(1)[-1] in closing:
                opening_bracket = closing[match.group(1)[-1]]
                opening[opening_bracket] += 1
                if opening[opening_bracket] == len(
                    [bracket for bracket in opening.values() if bracket != 0]
                ):
                    object_strings.append(string[start : match.end()])
                    stack = []
                    opening = {"{": 0, "[": 0}
                    closing = {"}": "{", "]": "["}
        if len(stack) > 0:
            print(f"Error: Incomplete object at end of string: {stack[-1]}")
        objects = []
        for object_string in object_strings:
            try:
                # Use ast.literal_eval() to safely evaluate the string as a Python object.
                obj = ast.literal_eval(object_string)
                objects.append(obj)
            except (ValueError, SyntaxError) as e:
                # If the string cannot be safely evaluated as a Python object,
                # log an error and move on to the next object.
                print(f"Error evaluating object string '{object_string}': {str(e)}")
        return objects
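
For illustration, a minimal sketch of repairing a truncated completion with Parser.fit; the cut-off string below is an invented example, and the keys read from the result follow the fit return structure above.

# Sketch only: the truncated string is an invented example of a cut-off model output.
parser = Parser()

truncated = "[{'T': 'SYMPTOM', 'E': 'abdominal pain'}, {'T': 'QUANTITY', 'E': '30-pound'"
result = parser.fit(truncated, max_completion_length=5)

print(result["status"])               # 'completed' if a closing sequence was found
print(result["data"]["completion"])   # best-effort repaired object
print(result["data"]["suggestions"])  # other candidate completions, longest first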
prompt = """You are a highly intelligent and accurate medical domain Named-entity recognition(NER) system. You take Passage as input and your task is to recognize and extract specific types of medical domain named entities in that given passage and classify into a set of entity types. Your output valid json and format is only [{{'T': type of entity from predefined entity types, 'E': entity in the input text}},...,{{'branch' : Appropriate branch of the passage ,'group': Appropriate Group of the passage}}] form, no other form. | |
Examples: | |
Input: The patient had abdominal pain and 30-pound weight loss then developed jaundice. He had epigastric's pain. A thin-slice CT scan was performed, which revealed a pancreatic's mass with involved lymph nodes and ring enhancing lesions with liver metastases | |
Output: [[{'T': 'SYMPTOM', 'E': 'abdominal pain'}, {'T': 'QUANTITY', 'E': '30-pound'}, {'T': 'SYMPTOM', 'E': 'jaundice'}, {'T': 'SYMPTOM', 'E': "epigastric\'s pain"}, {'T': 'TEST', 'E': 'thin-slice CT scan'}, {'T': 'ANATOMY', 'E': "pancreatic\'s mass"}, {'T': 'ANATOMY', 'E': 'ring enhancing lesions'}, {'T': 'ANATOMY', 'E': 'liver'}, {'T': 'DISEASE', 'E': 'metastases'}, {'branch': 'Health', 'group': 'Clinical medicine'}]] | |
Input: Dopamine (DA, a contraction of 3,4-dihydroxyphenethylamine) is a neuromodulatory molecule that plays several important roles in cells. It is an organic chemical of the catecholamine and phenethylamine families. Dopamine constitutes about 80% of the catecholamine content in the brain. It is an amine synthesized by removing a carboxyl group from a molecule of its precursor chemical, L-DOPA, which is synthesized in the brain and kidneys. Dopamine is also synthesized in plants and most animals. In the brain, dopamine functions as a neurotransmitter—a chemical released by neurons (nerve cells) to send signals to other nerve cells. Neurotransmitters are synthesized in specific regions of the brain, but affect many regions systemically. The brain includes several distinct dopamine pathways, one of which plays a major role in the motivational component of reward-motivated behavior. The anticipation of most types of rewards increases the level of dopamine in the brain,[4] and many addictive drugs increase dopamine release or block its reuptake into neurons following release.[5] Other brain dopamine pathways are involved in motor control and in controlling the release of various hormones. These pathways and cell groups form a dopamine system which is neuromodulatory.[5] | |
Output:""" | |

op_ = OpenAI_Complete(api_key='sk-aa', api_wait=10, api_retry=6)
dr = op_.execute_with_retry([prompt])
op_.model_output_with_parser(dr[0], 20)['parsed']['data']['completion']
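
If the unparsed completion and token accounting are also needed, the same raw response can be read with model_output; a small sketch (the exact usage keys come from the OpenAI completion response, not from this gist):

# Sketch: inspect the raw completion text and token usage for the same response.
raw = op_.model_output(dr[0])
print(raw["text"])   # unparsed completion string from choices[0]
print(raw["usage"])  # token counts, e.g. prompt_tokens / completion_tokens / total_tokens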