Created
February 4, 2020 01:24
-
-
Save timhughes/91ede6930e0b624beae101857c934ead to your computer and use it in GitHub Desktop.
rate_limiting_bucket.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
TokenBucket implementation from http://code.activestate.com/recipes/511490-implementation-of-the-token-bucket-algorithm/ | |
""" | |
import asyncio | |
from time import time | |
from math import floor | |
import uvicorn | |
from fastapi import BackgroundTasks | |
from fastapi import FastAPI, BackgroundTasks | |
import logging | |
# Root logging: everything from DEBUG up, rendered as "LEVEL:message".
logging.basicConfig(level=logging.DEBUG, format='%(levelname)s:%(message)s')
logger = logging.getLogger('someapp')

# The ASGI application; its routes are registered with @app.get below.
app = FastAPI(
    title="Some Web App",
    description="Some web api",
    version="2.5.0",
)
class TokenBucket(object):
    """An implementation of the token bucket algorithm.

    Adapted from
    http://code.activestate.com/recipes/511490-implementation-of-the-token-bucket-algorithm/

    The bucket starts full with ``capacity`` tokens and refills continuously
    at ``fill_rate`` tokens per second, never exceeding ``capacity``. The
    level is tracked as a float so fractional refill (e.g. ``fill_rate=0.5``
    polled once per second) accumulates instead of being thrown away; the
    ``tokens`` property and ``consume`` only ever report/hand out whole
    tokens.

    >>> bucket = TokenBucket(80, 0.5)
    >>> bucket.consume(10)
    10
    >>> bucket.consume(90)  # only 70 left, so that is all we get
    70
    """

    def __init__(self, tokens, fill_rate, clock=time):
        """Create a full bucket.

        :param tokens: capacity of the bucket (converted with ``int``); the
            bucket starts full.
        :param fill_rate: refill rate in tokens per second (converted with
            ``float``).
        :param clock: zero-argument callable returning the current time in
            seconds; defaults to ``time.time``. Inject a fake clock in tests.
        """
        self.capacity = int(tokens)
        # Float level so fractional refill credit is carried between reads.
        self._tokens = float(self.capacity)
        self.fill_rate = float(fill_rate)
        self._clock = clock
        self.timestamp = clock()

    def consume(self, tokens):
        """Take up to ``tokens`` tokens out of the bucket.

        Returns ``tokens`` if there were sufficient tokens, otherwise
        returns (and removes) all whole tokens currently in the bucket,
        which may be 0.

        :raises ValueError: if ``tokens`` is negative (that would push the
            bucket above its capacity).
        """
        if tokens < 0:
            raise ValueError("cannot consume a negative number of tokens")
        # Refill exactly once so the check and the deduction see the same
        # level; re-reading self.tokens would refill again in between.
        available = self.tokens
        granted = tokens if tokens <= available else available
        self._tokens -= granted
        return granted

    def get_tokens(self):
        """Refill for the time elapsed since the previous call and return
        the number of whole tokens now available."""
        now = self._clock()
        # Clamp at 0: time.time() is wall-clock time and can step backwards.
        elapsed = max(0.0, now - self.timestamp)
        self._tokens = min(float(self.capacity),
                           self._tokens + self.fill_rate * elapsed)
        # Always advance the timestamp, even while the bucket is full;
        # otherwise the idle time spent full would be credited again right
        # after the next drain, refilling the bucket instantly.
        self.timestamp = now
        return floor(self._tokens)

    tokens = property(get_tokens)
async def get_tokens(capacity: int = 10, fill_rate: int = 1, request_qty: int = 1):
    """Collect ``request_qty`` tokens from a fresh bucket, waiting for refills.

    A new ``TokenBucket(capacity, fill_rate)`` is created for this call.
    Tokens are drawn from it, sleeping 1 second between attempts, until
    ``request_qty`` have been collected.

    :param capacity: bucket size; if falsy or below 1, nothing is collected.
    :param fill_rate: refill rate in tokens per second.
    :param request_qty: number of tokens to collect.
    :returns: number of tokens collected. This is never more than
        ``request_qty``. It is less only when the bucket cannot refill
        (``fill_rate <= 0``).
    """
    response_qty = 0
    # TokenBucket int()-truncates capacity; anything below 1 would give an
    # empty bucket that never fills, and the loop below would never end.
    if not capacity or capacity < 1:
        return response_qty
    bucket = TokenBucket(capacity, fill_rate)
    while response_qty < request_qty:
        logger.info( f"Bucket has {bucket.tokens}. We need {request_qty} tokens")
        # Ask only for what is still missing so we never over-draw.
        response_qty += bucket.consume(request_qty - response_qty)
        logger.info(f"We now have {response_qty} tokens")
        if response_qty >= request_qty:
            break
        if fill_rate <= 0:
            # A drained bucket that never refills would spin forever.
            logger.warning("Bucket cannot refill; stopping at %s of %s tokens",
                           response_qty, request_qty)
            break
        await asyncio.sleep(1)
    return response_qty
@app.get("/")
async def root():
    """Return a fixed greeting payload for ``GET /``."""
    payload = {}
    payload["message"] = "Hello World"
    return payload
@app.get("/tokens")
async def tokens(background_tasks: BackgroundTasks, capacity: int = 100,
                 fill_rate: int = 1, request_qty: int = 1):
    """Schedule token collection as a background task and return at once.

    The query parameters are forwarded unchanged to ``get_tokens``. That
    function is registered with FastAPI's ``BackgroundTasks`` rather than
    awaited here.
    """
    background_tasks.add_task(get_tokens, capacity, fill_rate, request_qty)
    return {"message": "Notification sent in the background"}
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment