Parallel calling OpenAI with rate limit handling
""" | |
Created by: @Andrew-Chen-Wang | |
Parallel calling OpenAI with rate limit handling. | |
Required packages: | |
- openai | |
- tenacity (optional if you remove the decorators. Useful for 500 errors) | |
Usage: | |
You must call with with_raw_response to get rate limit headers like: | |
client.chat.completions.with_raw_response.create | |
To get the typical ChatCompletion response object, simply run: | |
```python | |
from openai.types.chat import ChatCompletion | |
r = await call_openai(client.chat.completions.with_raw_response.create, params) | |
r.parse(to=ChatCompletion) | |
``` | |
Warning: | |
This is meant for a single threaded application. If you have multiple servers, | |
you'll want to replace this locking mechanism with either: | |
1. To continue using this script, a dedicated server that handles all your OpenAI calls | |
e.g. API Gateway (recommended approach) | |
2. A distributed lock stored somewhere like Redis (must rewrite the locking mechanism here) | |
""" | |
import asyncio
import re
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, cast

import openai
from openai._legacy_response import LegacyAPIResponse
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

# Rate limiting
rateLimitRequestsMax: ContextVar[int] = ContextVar("RateLimitRequestsMax", default=5000)
rateLimitTokensMax: ContextVar[int] = ContextVar("RateLimitTokensMax", default=160000)
rateLimitRemainingRequests: ContextVar[int] = ContextVar("RateLimitRemainingRequests", default=5000)
rateLimitRemainingTokens: ContextVar[int] = ContextVar("RateLimitRemainingTokens", default=160000)
# in seconds
rateLimitResetRequests: ContextVar[float] = ContextVar("RateLimitResetRequests", default=0)
# in seconds
rateLimitResetTokens: ContextVar[float] = ContextVar("RateLimitResetTokens", default=0)

# Ensuring only one modification happens for OpenAI results at a time
openai_lock = asyncio.Lock()

def parse_time_left(time_str: str) -> float:
    """
    Returns time left in seconds. Allowed formats:
        1s
        12ms
        6m12s
        13m0s1ms
        1d2h3m4s
        1d4s3ms
    """
    # Define regex patterns for days, hours, minutes, seconds, and milliseconds
    patterns = {
        "days": r"(\d+)d",
        "hours": r"(\d+)h",
        "minutes": r"(\d+)m(?![s])",  # Negative lookahead to exclude 'ms'
        "seconds": r"(\d+\.?\d*)s(?![m])",  # Negative lookahead; handles decimals
        "milliseconds": r"(\d+)ms",
    }
    # Initialize total time in seconds
    total_seconds = 0
    # Loop through each time unit, find matches, and convert to seconds
    for unit, pattern in patterns.items():
        match = re.search(pattern, time_str)
        if match:
            # Seconds may be fractional (e.g. "58.171s"); the other units are integers
            value = float(match.group(1)) if unit == "seconds" else int(match.group(1))
            if unit == "days":
                total_seconds += value * 86400  # 1 day = 86400 seconds
            elif unit == "hours":
                total_seconds += value * 3600  # 1 hour = 3600 seconds
            elif unit == "minutes":
                total_seconds += value * 60  # 1 minute = 60 seconds
            elif unit == "seconds":
                total_seconds += value
            elif unit == "milliseconds":
                total_seconds += value / 1000  # 1 millisecond = 0.001 seconds
    return total_seconds

# We need to set all our values first
first_execution_lock = asyncio.Lock()


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def call_openai(
    async_func: Callable[[Any, Any], Awaitable[Any]], params: dict, event: asyncio.Event
) -> LegacyAPIResponse:
    response: LegacyAPIResponse | None = None

    async def run():
        try:
            return await async_func(**params)
        except openai.BadRequestError as e:
            if e.code == "context_length_exceeded":
                if params["model"] == "gpt-4-0125-preview":
                    raise e
                params["model"] = "gpt-4-0125-preview"
                return await async_func(**params)
            raise e

    if not event.is_set():
        should_run_without_lock = True
        async with first_execution_lock:
            # Only run the OpenAI call for the first to lock
            if not event.is_set():
                response = await run()
                should_run_without_lock = False
                event.set()
        if should_run_without_lock:
            response = await run()
    else:
        async with openai_lock:
            if rateLimitRemainingRequests.get() < 3:
                print(
                    f"Rate limit reached for requests with {rateLimitRemainingRequests.get()} "
                    f"requests remaining. Sleeping for {rateLimitResetRequests.get()}s."
                )
                await asyncio.sleep(0.1 + rateLimitResetRequests.get())
                rateLimitRemainingRequests.set(rateLimitRequestsMax.get())
            if rateLimitRemainingTokens.get() < 12000:
                print(
                    f"Rate limit reached for tokens with {rateLimitRemainingTokens.get()} tokens "
                    f"remaining. Sleeping for {rateLimitResetTokens.get()}s."
                )
                await asyncio.sleep(0.1 + rateLimitResetTokens.get())
                rateLimitRemainingTokens.set(rateLimitTokensMax.get())
        response = await run()
    response = cast(LegacyAPIResponse, response)
    async with openai_lock:
        rateLimitTokensMax.set(int(response.headers["x-ratelimit-limit-tokens"]))
        rateLimitRequestsMax.set(int(response.headers["x-ratelimit-limit-requests"]))
        rateLimitRemainingRequests.set(int(response.headers["x-ratelimit-remaining-requests"]))
        rateLimitRemainingTokens.set(int(response.headers["x-ratelimit-remaining-tokens"]))
        rateLimitResetRequests.set(parse_time_left(response.headers["x-ratelimit-reset-requests"]))
        rateLimitResetTokens.set(parse_time_left(response.headers["x-ratelimit-reset-tokens"]))
    return response
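For context, here is a minimal driver sketch showing one way the pieces above might be wired together. The `AsyncOpenAI` client setup, model name, prompts, and the `asyncio.gather` fan-out are my own illustrative assumptions, not part of the gist.

```python
# Hypothetical driver; assumes it lives in the same module as call_openai above.
import asyncio

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    event = asyncio.Event()  # shared so only the first call runs before headers are known

    prompts = ["Summarize asyncio locks.", "Explain context variables."]  # placeholder prompts
    tasks = [
        call_openai(
            client.chat.completions.with_raw_response.create,
            {
                "model": "gpt-4-0125-preview",
                "messages": [{"role": "user", "content": prompt}],
            },
            event,
        )
        for prompt in prompts
    ]
    raw_responses = await asyncio.gather(*tasks)
    completions = [r.parse(to=ChatCompletion) for r in raw_responses]
    for completion in completions:
        print(completion.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```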
Had an error parsing. Here's the update:
```python
def parse_time_left(time_str: str) -> float:
    """
    Returns time left in seconds. Allowed formats:
        1s
        12ms
        6m12s
        13m0s1ms
        1d2h3m4s
        1d4s3ms
        58.171s
    """
    # Define regex patterns for days, hours, minutes, seconds, and milliseconds
    patterns = {
        "days": r"(\d+)d",
        "hours": r"(\d+)h",
        "minutes": r"(\d+)m(?![s])",  # Negative lookahead to exclude 'ms'
        "seconds": r"(\d+\.?\d*)s(?![m])",  # Modified to handle decimal seconds
        "milliseconds": r"(\d+)ms",
    }
    # Initialize total time in seconds
    total_seconds = 0
    # Loop through each time unit, find matches, and convert to seconds
    for unit, pattern in patterns.items():
        match = re.search(pattern, time_str)
        if match:
            value = float(match.group(1)) if unit == "seconds" else int(match.group(1))  # Use float for seconds
            if unit == "days":
                total_seconds += value * 86400  # 1 day = 86400 seconds
            elif unit == "hours":
                total_seconds += value * 3600  # 1 hour = 3600 seconds
            elif unit == "minutes":
                total_seconds += value * 60  # 1 minute = 60 seconds
            elif unit == "seconds":
                total_seconds += value
            elif unit == "milliseconds":
                total_seconds += value / 1000  # 1 millisecond = 0.001 seconds
    return total_seconds
```
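As a quick sanity check (my own worked values, not part of the original comment), the updated function converts the documented formats like this:

```python
# Expected conversions for a few of the formats listed in the docstring:
print(parse_time_left("1s"))        # 1.0
print(parse_time_left("12ms"))      # 0.012
print(parse_time_left("6m12s"))     # 372.0
print(parse_time_left("1d2h3m4s"))  # 93784.0
print(parse_time_left("58.171s"))   # 58.171
```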
Thanks @hansvdam. Updated the "seconds" pattern from r"(\d+)s(?![m])"
to r"(\d+\.?\d*)s(?![m])".
Thanks! How ridiculous it is that we have to parse the reset delay this way...