Last active
November 8, 2019 22:05
-
-
Save paulo-raca/25e153302cd696e63e2cd598a3933594 to your computer and use it in GitHub Desktop.
TokenBucket in Python-Asyncio
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import time | |
class Timer:
    """Clock interface used by ``TokenBucket``.

    Subclasses implement ``now()`` (current time, in the timer's own unit)
    and the coroutine ``wait(time)`` (suspend for ``time`` units of that
    same clock).
    """

    def now(self):
        """Return the current time of this clock."""
        # NotImplementedError is the exception; ``NotImplemented`` is a
        # comparison sentinel and calling it raises TypeError instead.
        raise NotImplementedError()

    async def wait(self, time):
        """Suspend the caller for ``time`` units of this clock."""
        raise NotImplementedError()
class WallTimer(Timer):
    """Timer driven by the real system clock and the asyncio event loop.

    Note: ``time.time()`` follows the system wall clock, which can jump if
    the clock is adjusted; ``time.monotonic()`` would not.
    """

    def now(self):
        """Return the current wall-clock time from ``time.time()``."""
        current = time.time()
        return current

    async def wait(self, time):
        """Sleep for ``time`` seconds without blocking the event loop."""
        await asyncio.sleep(time)
class TestTimer(Timer):
    """Simulated clock: time starts at 0 and advances only through ``wait``.

    ``wait`` returns immediately after moving the clock forward, so code
    under test runs without real sleeping.
    """

    def __init__(self):
        # Kept under a private name: assigning to ``self.now`` would shadow
        # the ``now()`` method with an int and make ``timer.now()`` fail.
        self._now = 0

    def now(self):
        """Return the current simulated time."""
        return self._now

    async def wait(self, time):
        """Advance the simulated clock by ``time`` without sleeping."""
        self._now += time
class TokenBucket:
    """Asyncio token-bucket rate limiter.

    Tokens accumulate at ``generation_rate`` per unit of ``timer`` time, up to
    ``max_capacity``.
    - ``tryAcquire`` and ``drain`` act immediately.
    - ``acquire`` waits until enough tokens exist.
    - Concurrent ``acquire`` calls are serialized through ``sync_lock``.

    Args:
        initial_capacity: Tokens available at construction.
        max_capacity: Upper bound on stored tokens.
        generation_rate: Tokens generated per unit of timer time. With a rate
            of 0, ``acquire`` raises ``ZeroDivisionError`` when tokens are
            missing.
        timer: Object providing ``now()`` and the coroutine ``wait(time)``.
            When omitted (or None), a new ``WallTimer`` is created for this
            bucket.
    """

    def __init__(self, initial_capacity=0, max_capacity=1, generation_rate=1, timer=None):
        if timer is None:
            # Build the default per instance instead of evaluating
            # ``WallTimer()`` once when the class is defined.
            timer = WallTimer()
        self.current_capacity = initial_capacity
        self.max_capacity = max_capacity
        self.timer = timer
        self.last_timestamp = timer.now()
        self.generation_rate = generation_rate
        # Guards current_capacity / max_capacity / last_timestamp updates.
        self.current_capacity_lock = asyncio.Lock()
        # Serializes acquire() calls.
        self.sync_lock = asyncio.Lock()

    def _update_capacity(self):
        """Add the tokens generated since the last update, capped at max_capacity."""
        now = self.timer.now()
        # A clock that steps backwards (possible with time.time()) must not
        # remove tokens, so negative elapsed time counts as zero.
        elapsed = max(0, now - self.last_timestamp)
        self.last_timestamp = now
        self.current_capacity = min(
            self.current_capacity + elapsed * self.generation_rate,
            self.max_capacity,
        )

    # Backward-compatible alias for the original misspelled name.
    _update_capactiy = _update_capacity

    async def tryAcquire(self, n=1):
        """Try to consume ``n`` tokens immediately, without waiting.

        Returns:
            True if the tokens were consumed, False if too few were available.
        """
        async with self.current_capacity_lock:
            self._update_capacity()
            if n <= self.current_capacity:
                self.current_capacity -= n
                return True
            return False

    async def drain(self, n=None):
        """Consume as many tokens as available, up to ``n`` (all if None).

        Returns:
            The number of tokens consumed.
        """
        async with self.current_capacity_lock:
            self._update_capacity()
            if n is None or n > self.current_capacity:
                n = self.current_capacity
                self.current_capacity = 0
            else:
                self.current_capacity -= n
            return n

    async def acquire(self, amount=1):
        """Wait until ``amount`` tokens are available, then consume them.

        Calls are serialized through ``sync_lock``. ``tryAcquire`` and
        ``drain`` do not take that lock, so they can consume tokens ahead of a
        waiting ``acquire``.

        If ``amount`` exceeds ``max_capacity``, the cap is raised for the
        duration of the call so the request can be satisfied. The original
        cap is restored on every exit path, including cancellation.
        """
        async with self.sync_lock:
            async with self.current_capacity_lock:
                old_max_capacity = self.max_capacity
                self.max_capacity = max(self.max_capacity, amount)
            try:
                while True:
                    async with self.current_capacity_lock:
                        self._update_capacity()
                        if amount <= self.current_capacity:
                            self.current_capacity -= amount
                            return
                        missing = amount - self.current_capacity
                        wait = missing / self.generation_rate
                    await self.timer.wait(wait)
            finally:
                # No await here, so no other coroutine can interleave, and the
                # restore completes even while a CancelledError propagates.
                # Any tokens banked above the original cap are discarded.
                self.max_capacity = old_max_capacity
                self.current_capacity = min(self.current_capacity, self.max_capacity)
async def main():
    """Endless demo: print lines of 20 dots, one bucket token per dot.

    The bucket generates 10 tokens per timer unit and holds at most 10,
    so each line of dots is paced by token generation. After each line the
    demo sleeps for one second.
    """
    bucket = TokenBucket(generation_rate=10, max_capacity=10)
    dots_per_line = 20
    while True:
        for _ in range(dots_per_line):
            await bucket.acquire()
            print(".", end='', flush=True)
        print()
        await asyncio.sleep(1)
if __name__ == "__main__":
    # Run the endless demo only when executed as a script, never on import.
    asyncio.run(main(), debug=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment