A standalone implementation of interacting with the OpenAI ChatGPT API. Uses async iterators for streaming responses and supports multiple concurrent generations with a single API call.
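As a rough sketch of how the lower-level entry point is consumed (names refer to the functions defined in the file below; a full runnable example is in `main()`), each yielded dict holds the accumulated state of all `n` choices, so a single call drives several generations at once:

    async for state in chat_completion_async(MODEL, messages, n=2, max_tokens=32):
        texts = ["".join(choice["content"]) for choice in state["choices"].values()]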
import asyncio
import json
import os
import typing

import httpx
import tiktoken
from httpx import AsyncClient

API_BASE_URL = "https://api.openai.com"
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
MODEL = "gpt-3.5-turbo"

client = AsyncClient(
    base_url=API_BASE_URL,
    headers={
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OPENAI_API_KEY}",
    },
)


def gen_logit_bias_dict(biases: dict[str, float], model: str = MODEL) -> dict[int, int]:
    """Convert a token-text -> bias mapping into the token-id -> bias form expected by the API."""
    encoder = tiktoken.encoding_for_model(model)
    # Bias values are given in the range -1 to 1, so we scale the values in the
    # biases dict to the range -100 to 100 expected by the API.
    # Todo - make this more robust by also covering the other tokens that consist of some word
    # plus a punctuation mark or whitespace
    return {encoder.encode(token)[0]: int(bias * 100) for token, bias in biases.items()}

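# A minimal usage sketch (illustrative values; the resulting token ids depend on the
# model's tokenizer): biases passed in the -1..1 range come out scaled to -100..100,
# keyed by the first token id of each entry, e.g.
#   gen_logit_bias_dict({"ravens": 0.5, "alpacas": -1.0})
#   # -> {<first token id of "ravens">: 50, <first token id of "alpacas">: -100}

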
async def chat_completion_async(
    model: str,
    messages: typing.List[dict],
    temperature: float = 1,
    top_p: float = 1,
    n: int = 1,
    stop: typing.Optional[typing.Union[str, typing.List[str]]] = None,
    max_tokens: typing.Optional[int] = None,
    presence_penalty: float = 0,
    frequency_penalty: float = 0,
    logit_bias: typing.Optional[dict] = None,
) -> typing.AsyncIterator[dict]:
"""Stream chat completion results from the OpenAI API. | |
Args: | |
model: The model to use for completion. | |
messages: A list of messages to use as context. | |
Example messages: | |
[ | |
{"role": "system", "content": "You are a helpful assistant."}, | |
{"role": "user", "content": "Who won the world series in 2020?"}, | |
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, | |
{"role": "user", "content": "Where was it played?"} | |
] | |
temperature: What sampling temperature to use, between 0 and 2. | |
Higher values like 0.8 will make the output more random, while lower values like 0.2 | |
will make it more focused and deterministic. We generally recommend altering this or top_p but not both. | |
top_p: An alternative to sampling with temperature, called nucleus sampling, where the model | |
considers the results of the tokens with top_p probability mass. | |
So 0.1 means only the tokens comprising the top 10% probability mass are considered. | |
n: The number of choices to return. | |
stop: Up to 4 sequences where the API will stop generating further tokens. | |
max_tokens: The maximum number of tokens to generate in the chat completion. | |
presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on | |
whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | |
frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing | |
frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | |
logit_bias: Modify the likelihood of specified tokens appearing in the completion. | |
Returns: | |
AsyncIterator[dict]: An async iterator that yields server-sent events as they arrive. | |
""" | |
    data = dict(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        n=n,
        stream=True,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
    )
    if stop is not None:
        data["stop"] = stop
    if max_tokens is not None:
        data["max_tokens"] = max_tokens
    if logit_bias is not None:
        data["logit_bias"] = logit_bias
    async with client.stream("POST", "/v1/chat/completions", json=data) as response:
        response: httpx.Response
        objs = {
            "choices": {i: {"content": [], "finish_reason": None} for i in range(n)},
            "model": model,
        }
        buffer = bytearray()
        async for chunk in response.aiter_bytes():
            # Parse the JSON contents of the server-sent events.
            # Chunk example:
            # b'data: {"id":"chatcmpl-76fSMKiKMWvkA790WN9ZFY7Mf8Q0W","object":"chat.completion.chunk",
            # "created":1681823682,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,
            # "finish_reason":null}]}\n\ndata: {"id":"chatcmpl-76fSMKiKMWvkA790WN9ZFY7Mf8Q0W","object":
            # "chat.completion.chunk","created":1681823682,"model":"gpt-3.5-turbo-0301","choices":[{"delta":
            # {"content":"Hi"},"index":0,"finish_reason":null}]}\n\n'
            lines = chunk.decode("utf-8").splitlines()
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                if line == "data: [DONE]":
                    return
                elif line.startswith("data: "):
                    line = line[6:]
                elif buffer:
                    # Continuation of a JSON payload that was split across chunks.
                    line = buffer.decode("utf-8") + line
                    buffer.clear()
                else:
                    raise ValueError(f"Unexpected line: {line}")
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Incomplete JSON - stash it and wait for the rest in the next chunk.
                    buffer.extend(line.encode("utf-8"))
                    continue
                for choice in obj["choices"]:
                    i = choice["index"]
                    if "delta" in choice:
                        for key, value in choice["delta"].items():
                            if key == "content":
                                objs["choices"][i]["content"].append(value)
                            else:
                                objs["choices"][i].update({key: value})
                    objs["choices"][i]["finish_reason"] = choice["finish_reason"]
                objs["model"] = obj["model"]
                objs["created"] = obj["created"]
                objs["id"] = obj["id"]
                yield objs

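# Shape of the dict yielded above, sketched from the accumulation logic (example values taken
# from the chunk sample in the comment above; "content" holds the delta strings received so far):
#   {
#       "choices": {0: {"content": ["Hi"], "finish_reason": None}, ...},
#       "model": "gpt-3.5-turbo-0301",
#       "created": 1681823682,
#       "id": "chatcmpl-76fSMKiKMWvkA790WN9ZFY7Mf8Q0W",
#   }

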
def chat_completion_text_aiters(
    model: str,
    messages: typing.List[dict],
    temperature: float = 1,
    top_p: float = 1,
    n: int = 1,
    stop: typing.Optional[typing.Union[str, typing.List[str]]] = None,
    max_tokens: typing.Optional[int] = None,
    presence_penalty: float = 0,
    frequency_penalty: float = 0,
    logit_bias: typing.Optional[dict] = None,
) -> tuple[typing.AsyncIterator[str], ...]:
    """A convenience function that returns a tuple of async iterators, one per choice.

    See Also:
        :func:`chat_completion_async` for the full list of parameters.
    """
    chat_completion = chat_completion_async(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        n=n,
        stop=stop,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
    )
    async def text_aiter(q: asyncio.Queue) -> typing.AsyncIterator[str]:
        prev_content = ""
        while True:
            item = await q.get()
            if item is None:
                q.task_done()
                break
            if item != prev_content:
                prev_content = item
                yield item
            q.task_done()

    queues = [asyncio.Queue() for _ in range(n)]
    aiters = tuple(text_aiter(q) for q in queues)

    async def text_aiter_wrapper() -> None:
        async for event in chat_completion:
            for choice, queue in zip(event["choices"].values(), queues):
                await queue.put("".join(choice["content"]))
        await asyncio.gather(*[queue.put(None) for queue in queues])
        await asyncio.gather(*[queue.join() for queue in queues])

    # Fire-and-forget feeder task; in longer-lived code, keep a reference to it
    # so it isn't garbage-collected before it finishes.
    asyncio.create_task(text_aiter_wrapper())
    return aiters

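# Note on semantics: each iterator yields the full text accumulated so far for its choice
# (not just the newest delta), and only when that text has changed - which is why the
# example output below shows progressively longer lines.

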
async def main():
    messages = [
        {
            "role": "user",
            "content": "Tell me something interesting about ravens that is also true of alpacas.",
        }
    ]
    cont1, cont2 = chat_completion_text_aiters(MODEL, messages, max_tokens=32, n=2)

    async def show_text(async_iterable: typing.AsyncIterator[str]) -> None:
        async for text in async_iterable:
            print(text)

    await asyncio.gather(show_text(cont1), show_text(cont2))

"""Example output: | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed using | |
Ravens and alpacas are both highly intelligent animals with complex social behaviors. Ravens | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed using tools | |
Ravens and alpacas are both highly intelligent animals with complex social behaviors. Ravens are | |
Ravens and alpacas are both highly intelligent animals with complex social behaviors. Ravens are known | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed using tools and | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed using tools and solving | |
Ravens and alpacas are both highly intelligent animals with complex social behaviors. Ravens are known for | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed using tools and solving complex | |
Ravens and alpacas are both highly intelligent animals with complex social behaviors. Ravens are known for their | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed using tools and solving complex problems | |
Ravens and alpacas are both highly intelligent animals with complex social behaviors. Ravens are known for their problem | |
Ravens and alpacas are both highly intelligent animals with complex social behaviors. Ravens are known for their problem-solving | |
Ravens and alpacas are both highly intelligent animals. Ravens have been observed using tools and solving complex problems, | |
""" | |
if __name__ == "__main__": | |
import asyncio | |
asyncio.run(main()) |