Last active
September 18, 2023 05:34
-
-
Save davidliyutong/355bb57d414a209a7291c97b592bfec2 to your computer and use it in GitHub Desktop.
Python helper class to call an API under a rate limit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import heapq | |
from time import monotonic | |
from typing import Iterable, Dict, Any, List, Callable, Optional | |
import datetime | |
import inspect | |
class RateLimit:
    """
    One rate limit: at most `num` units of cost inside a window of `duration`
    """

    def __init__(self, num: int, duration: datetime.timedelta):
        """
        num: number of calls permitted
        duration: time window
        """
        self._num = num
        self._duration = duration

    @property
    def num(self) -> int:
        """
        return number of calls permitted
        """
        return self._num

    @property
    def duration(self) -> float:
        """
        return duration in seconds
        """
        # timedelta.seconds is only the 0..86399 component: it ignores whole
        # days and microseconds. total_seconds() is the full window length.
        return self._duration.total_seconds()
class CostTuple:
    """
    A mutable (cost, timestamp) record; instances are ordered by timestamp
    """

    def __init__(self, cost: float, timestamp: float):
        """
        cost: cost of the job
        timestamp: timestamp of the job
        """
        self.cost = cost
        self.timestamp = timestamp

    def __tuple__(self):
        """
        return the record as a (cost, timestamp) tuple
        """
        return self.cost, self.timestamp

    # NOTE: this conversion used to be a method named __dict__. That name is
    # the instance attribute dictionary; defining a method with it made
    # vars(obj) and obj.__dict__ return a bound method instead of a dict.
    def to_dict(self):
        """
        return the record as a {"cost": ..., "timestamp": ...} dict
        """
        return {"cost": self.cost, "timestamp": self.timestamp}

    def __repr__(self):
        return f"CostTuple(cost={self.cost}, timestamp={self.timestamp})"

    def __lt__(self, other):
        # only the timestamp takes part in ordering; cost is ignored
        return self.timestamp < other.timestamp
class CostQueue: | |
""" | |
A sliding window priority queue, coroutine-safe | |
""" | |
def __init__(self): | |
self._queue = [] | |
self._mutex = asyncio.Lock() | |
heapq.heapify(self._queue) | |
async def put(self, cost_tuple: CostTuple): | |
async with self._mutex: | |
heapq.heappush(self._queue, cost_tuple) | |
async def get(self): | |
async with self._mutex: | |
return heapq.heappop(self._queue) | |
async def peek(self): | |
async with self._mutex: | |
return self._queue[0] | |
def empty(self): | |
return not self._queue | |
async def size(self): | |
async with self._mutex: | |
return len(self._queue) | |
async def sum(self): | |
async with self._mutex: | |
return sum([s.cost for s in self._queue]) | |
class RateLimitedCaller:
    """
    Run jobs while keeping the summed cost of recent jobs below every
    configured RateLimit. One CostQueue (sliding window) is kept per limit.
    """

    def __init__(self, limits: List[RateLimit]):
        """
        limits: a list of RateLimit
        """
        self.cost_queues = [CostQueue() for _ in limits]  # one sliding window per limit
        self.limits_param = limits  # limits, index-aligned with cost_queues

    async def submit(
            self,
            fn, args: Optional[Iterable[Any]] = None,
            kwargs: Optional[Dict[str, Any]] = None,
            cost: Optional[float] = None,
            cost_cb: Optional[Callable] = None
    ):
        """
        fn: function to call (plain or coroutine function); None only books the cost
        args: args to pass to fn
        kwargs: kwargs to pass to fn
        cost: pre-calculated cost, defaults to 1 when omitted
        cost_cb: callback to adjust cost, called once per limit as
                 cost_cb(cost_tuple, result, *args, **kwargs) after fn returns

        return: whatever fn returns, or None when fn is None
        """
        # check if we are over the limits and wait if necessary
        cost_tuples: List[CostTuple] = []
        for q, limit_param in zip(self.cost_queues, self.limits_param):
            while True:
                # clear out dated records that slid out of the window
                current_time = monotonic()
                while not q.empty() and current_time - (await q.peek()).timestamp > limit_param.duration:
                    await q.get()

                if (await q.sum()) >= limit_param.num:
                    # if busy, sleep until the oldest record leaves the window.
                    # remaining = window length - age of the oldest record;
                    # the earlier form `duration - now - ts` was (almost)
                    # always negative, so the loop spun without waiting.
                    oldest = await q.peek()
                    remaining = limit_param.duration - (current_time - oldest.timestamp)
                    await asyncio.sleep(max(remaining, 0))
                    continue

                # else, add the job, pass this limit check
                cost_tuple = CostTuple(1 if cost is None else float(cost), current_time)
                await q.put(cost_tuple)
                cost_tuples.append(cost_tuple)
                break

        # call the job
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs
        if fn is None:
            return None
        if inspect.iscoroutinefunction(fn):
            res = await fn(*args, **kwargs)
        else:
            # NOTE: a plain function runs inline and blocks the event loop
            # until it returns
            res = fn(*args, **kwargs)

        # calculate cost after the job is done
        if cost_cb is not None:
            for cost_tuple in cost_tuples:
                cost_cb(cost_tuple, res, *args, **kwargs)  # this will modify cost_tuple.cost
        return res

    async def wait(self, jobs):
        """
        jobs: a list of jobs (awaitables, e.g. the coroutines returned by submit)

        return: list of results in the same order as jobs
        """
        return await asyncio.gather(*jobs)
async def main():
    """
    Demo driver: pushes 100 jobs of each sample task through one shared
    RateLimitedCaller and prints the results and the elapsed wall time.
    """
    import time
    import httpx
    import requests

    # --- sample jobs -------------------------------------------------------
    def task1(x: int, y: int):
        """
        CPU-bound sample: prints the sum, burns some cycles, returns the sum
        """
        print(f"x({x})+y({y})={x + y}")
        for _ in range(100000):
            # stand-in for a long running computation
            _scratch = x + y
        return x + y

    def task2(path: str, msg: str):
        """
        file sample: appends msg as one line to the file at path
        """
        print(f"write: {msg}")
        with open(path, "a+") as log_file:
            log_file.write(msg + "\n")

    def task3(url):
        """
        blocking HTTP sample based on requests
        """
        response = requests.get(url)
        body = response.text
        print(len(body))
        return body

    async def task4(url):
        """
        non-blocking HTTP sample based on httpx
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
        body = response.text
        print(len(body))
        return body

    async def task5(url):
        """
        same as task4; paired with cost5 to show post-call cost adjustment
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
        body = response.text
        print(len(body))
        return body

    def cost5(cost_tuple, res, *args, **kwargs):
        """
        re-price a finished job: one cost unit per 1000 characters of result
        """
        cost_tuple.cost = len(res) / 1000

    # --- shared caller: two limits checked for every job ---------------------
    long_window = RateLimit(2500, datetime.timedelta(seconds=15))  # 2500 cost units per 15 seconds
    burst_window = RateLimit(500, datetime.timedelta(seconds=1))  # 500 cost units per second (burst)
    caller = RateLimitedCaller([long_window, burst_window])

    # each row: (job, args factory, kwargs factory, cost factory, cost callback)
    tasks = [
        (task1, lambda idx: (0, idx), lambda idx: None, lambda idx: 100, None),
        (task2, lambda idx: None, lambda idx: {"path": "./log", "msg": str(idx)}, lambda idx: 1000, None),
        (task3, lambda idx: ("https://www.baidu.com",), lambda idx: None, lambda idx: 1000, None),
        (task4, lambda idx: ("https://www.baidu.com",), lambda idx: None, lambda idx: 1000, None),
        (task5, lambda idx: ("https://www.baidu.com",), lambda idx: None, lambda idx: 1000, cost5),
    ]

    for task, vargs, vkwargs, cost, cost_cb in tasks:
        start_t = time.time()
        # submit() only builds coroutines here; they start running in wait()
        pending = [
            caller.submit(task, args=vargs(i), kwargs=vkwargs(i), cost=cost(i), cost_cb=cost_cb)
            for i in range(100)
        ]
        print("all jobs submitted")
        results = await caller.wait(pending)  # run every job to completion
        print(results)
        print(f"duration={time.time() - start_t}")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment