Skip to content

Instantly share code, notes, and snippets.

@paulo-raca
Last active November 8, 2019 22:05
Show Gist options
  • Save paulo-raca/25e153302cd696e63e2cd598a3933594 to your computer and use it in GitHub Desktop.
Save paulo-raca/25e153302cd696e63e2cd598a3933594 to your computer and use it in GitHub Desktop.
TokenBucket in Python-Asyncio
import asyncio
import time
class Timer:
    """Clock interface used by TokenBucket.

    Subclasses must implement ``now()`` and ``wait()``. The base class only
    signals that these methods are abstract.
    """

    def now(self):
        """Return the current time as a number in this timer's units."""
        # NotImplementedError (an exception) -- not NotImplemented (a constant
        # that is not callable).
        raise NotImplementedError()

    async def wait(self, time):
        """Suspend for ``time`` units of this timer's clock."""
        raise NotImplementedError()
class WallTimer(Timer):
    """Timer backed by the real wall clock and the asyncio event loop."""

    def now(self):
        """Return the current wall-clock time as given by ``time.time()``."""
        current = time.time()
        return current

    async def wait(self, time):
        """Suspend the calling coroutine for ``time`` seconds via ``asyncio.sleep``."""
        await asyncio.sleep(time)
class TestTimer(Timer):
    """Deterministic fake clock: time starts at 0 and only advances via wait()."""

    # The name starts with "Test"; stop pytest from trying to collect this
    # helper as a test class when it is imported into a test module.
    __test__ = False

    def __init__(self):
        # Kept under a private name so it does not shadow the now() method
        # (an instance attribute named ``now`` would hide the method and make
        # ``timer.now()`` try to call an int).
        self._now = 0

    def now(self):
        """Return the simulated current time."""
        return self._now

    async def wait(self, time):
        """Advance the simulated clock by ``time`` without actually sleeping."""
        self._now += time
class TokenBucket:
    """Asyncio token-bucket rate limiter.

    Tokens accumulate at ``generation_rate`` per unit of ``timer`` time, up to
    ``max_capacity``. Because refill is proportional to elapsed time, token
    counts may be fractional.

    Args:
        initial_capacity: Tokens available immediately after construction.
        max_capacity: Upper bound on the number of stored tokens.
        generation_rate: Tokens generated per unit of ``timer.now()`` time.
        timer: Object providing ``now()`` and ``async wait(delay)``. When
            omitted, a new ``WallTimer`` is created for this bucket.
    """

    def __init__(self, initial_capacity=0, max_capacity=1, generation_rate=1, timer=None):
        if timer is None:
            # Built per instance instead of as a default argument, which would
            # be evaluated once at class-definition time and shared by all buckets.
            timer = WallTimer()
        self.current_capacity = initial_capacity
        self.max_capacity = max_capacity
        self.timer = timer
        self.last_timestamp = timer.now()
        self.generation_rate = generation_rate
        # Held around every read-modify-write of the capacity state.
        self.current_capacity_lock = asyncio.Lock()
        # Serializes acquire() calls; asyncio.Lock wakes waiters in FIFO order.
        self.sync_lock = asyncio.Lock()

    def _update_capacity(self):
        """Credit tokens generated since the last update, capped at max_capacity.

        All callers in this class hold ``current_capacity_lock``.
        """
        now = self.timer.now()
        elapsed = now - self.last_timestamp
        self.last_timestamp = now
        self.current_capacity = min(
            self.current_capacity + elapsed * self.generation_rate,
            self.max_capacity,
        )

    # Backward-compatible alias for the original (misspelled) private name.
    _update_capactiy = _update_capacity

    async def tryAcquire(self, n=1):
        """Consume ``n`` tokens if they are available right now.

        Does not wait and does not queue behind pending acquire() calls.

        Returns:
            True if ``n`` tokens were consumed, False otherwise (nothing consumed).
        """
        async with self.current_capacity_lock:
            self._update_capacity()
            if n <= self.current_capacity:
                self.current_capacity -= n
                return True
            return False

    async def drain(self, n=None):
        """Consume as many tokens as possible, up to ``n``, without waiting.

        Args:
            n: Maximum number of tokens to take; ``None`` means all available.

        Returns:
            The number of tokens actually consumed (may be fractional).
        """
        async with self.current_capacity_lock:
            self._update_capacity()
            if n is None or n > self.current_capacity:
                n = self.current_capacity
                self.current_capacity = 0
            else:
                self.current_capacity -= n
            return n

    async def acquire(self, amount=1):
        """Wait until ``amount`` tokens are available, then consume them.

        Concurrent acquire() calls are served one at a time in arrival order.
        tryAcquire() and drain() bypass that queue and are not fair to it.

        If ``amount`` exceeds ``max_capacity``, the cap is raised to ``amount``
        for the duration of this call so the request can be satisfied at all.
        Concurrent tryAcquire()/drain() calls also see the raised cap. The
        original cap is restored when the call returns, raises, or is cancelled.
        """
        async with self.sync_lock:
            async with self.current_capacity_lock:
                old_max_capacity = self.max_capacity
                self.max_capacity = max(old_max_capacity, amount)
            try:
                while True:
                    async with self.current_capacity_lock:
                        self._update_capacity()
                        if amount <= self.current_capacity:
                            self.current_capacity -= amount
                            return
                        # Time until the missing tokens are generated.
                        delay = (amount - self.current_capacity) / self.generation_rate
                    await self.timer.wait(delay)
            finally:
                # Restore the cap even when waiting was cancelled or raised.
                # No await happens here, so no other coroutine can interleave.
                # Any tokens stored above the restored cap are clamped by the
                # next _update_capacity().
                self.max_capacity = old_max_capacity
async def main():
    """Demo: forever print lines of 20 dots, rate-limited by a token bucket.

    The bucket refills 10 tokens per second with room for 10, and each dot
    costs one token. A one-second pause separates lines.
    """
    bucket = TokenBucket(max_capacity=10, generation_rate=10)
    dots_per_line = 20

    async def emit_line():
        # One token per dot; flush so the pacing is visible as it happens.
        for _ in range(dots_per_line):
            await bucket.acquire()
            print(".", end='', flush=True)
        print()

    while True:
        await emit_line()
        await asyncio.sleep(1)
if __name__ == "__main__":
    # Run the endless demo only when executed as a script, never on import.
    asyncio.run(main(), debug=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment