Created
May 20, 2025 13:51
-
-
Save exhuma/5076e7100a7a1f96cd18ad2f1f74473e to your computer and use it in GitHub Desktop.
Rate limiter for ASGI apps, without dependency on any higher-level framework.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cspell: ignore ASGI
"""
A simple rate-limiting middleware for ASGI applications.

This middleware uses the `limits` library to enforce rate limits on incoming
requests. It allows you to specify a default rate limit and custom limits for
request paths matching given regular expressions. Hits are counted separately
for each client IP and request path.

Usage::

    app = FastAPI()
    app.add_middleware(
        RateLimit,
        default_limit="10/1 minute",
        limits={
            r"^/$": "5/30 seconds",
        },
    )
"""
import math
import re
import time

from limits import RateLimitItem, parse, storage, strategies
# In-memory counter store shared by every ``RateLimit`` instance in this
# module. NOTE(review): presumably per-process, so each worker of a
# multi-process server keeps its own counters -- confirm before relying on
# globally enforced limits.
LIMITS_STORAGE = storage.MemoryStorage()
# Fixed-window strategy used for every hit and window-stats lookup in this
# module.
LIMITER = strategies.FixedWindowRateLimiter(storage=LIMITS_STORAGE)
# NOTE(review): not referenced in this module; kept in case other modules
# import it.
ONE_PER_MINUTE = parse("1/minute")
def _get_x_headers(limit, client, path):
    """
    Build the rate-limit response headers for one client/path counter.

    The counter is looked up in ``LIMITER`` under the key
    (``"path"``, *client*, *path*).

    :param limit: The rate limit item that applies to the request.
    :param client: The client host (IP address) the counter belongs to.
    :param path: The request path the counter belongs to.
    :return: A list of ``(name, value)`` byte-string tuples suitable for an
        ASGI ``http.response.start`` message. ``retry-after`` is included
        only when no requests remain in the current window.
    """
    stats = LIMITER.get_window_stats(limit, "path", client, path)
    # ASGI requires lowercase header names in ``http.response.start``.
    output = [
        # Clients expect a plain integer here, not the textual
        # "N per M unit" representation of the limit.
        (b"x-ratelimit-limit", str(limit.amount).encode()),
        (b"x-ratelimit-remaining", str(stats.remaining).encode()),
        # Window reset time as a Unix timestamp in whole seconds.
        (b"x-ratelimit-reset", str(int(stats.reset_time)).encode()),
    ]
    if stats.remaining == 0:
        # Retry-After must be a non-negative number of seconds. Round up so a
        # client honouring it does not come back while the window is still
        # closed, and clamp at zero in case the window has just expired.
        retry_after = max(0, math.ceil(stats.reset_time - time.time()))
        output.append((b"retry-after", str(retry_after).encode()))
    return output
async def _rate_limit_exceeded_response(
    limit: RateLimitItem, send, client, path
):
    """
    Send a complete ``429 Too Many Requests`` plain-text response.

    :param limit: The rate limit item that was exceeded.
    :param send: The ASGI ``send`` callable of the current connection.
    :param client: The client host (IP address) whose counter is exhausted.
    :param path: The request path whose counter is exhausted.
    """
    # Explicit 1-tuple so %-formatting never tries to unpack ``limit``.
    body = b"Rate limit %r exceeded. Try again later.\n" % (limit,)
    # ASGI requires lowercase header names in ``http.response.start``; the
    # body is fixed, so its length can be declared up front.
    headers = [
        (b"content-type", b"text/plain"),
        (b"content-length", str(len(body)).encode()),
    ]
    headers.extend(_get_x_headers(limit, client, path))
    await send(
        {
            "type": "http.response.start",
            "status": 429,
            "headers": headers,
        }
    )
    await send(
        {
            "type": "http.response.body",
            "body": body,
        }
    )
class RateLimit:
    """
    Rate limiting middleware for ASGI applications.

    This middleware uses the `limits` library to enforce rate limits on
    incoming HTTP requests. The request path is checked against the
    configured patterns in insertion order; the first pattern that matches
    (``re.match``, i.e. anchored at the start of the path) selects the limit,
    otherwise *default_limit* applies. Hits are counted separately for each
    (client IP, request path) pair, so a pattern matching many paths grants
    each distinct path its own budget. Non-HTTP scopes are passed through
    untouched.

    :param app: The ASGI application to wrap.
    :param default_limit: The default rate limit to apply to all paths, in
        the notation accepted by ``limits.parse`` (e.g. ``"10/1 minute"``).
    :param limits: A mapping from regular expressions of paths to limits in
        the same notation.
    """

    def __init__(self, app, *, default_limit: str, limits: dict[str, str]):
        # Assigned per instance (not as a shared class-level dict).
        self._limits: dict[re.Pattern, RateLimitItem] = {
            re.compile(pattern): parse(spec)
            for pattern, spec in limits.items()
        }
        self._default_limit = parse(default_limit)
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        path = scope.get("path", "")
        # ``client`` is optional in the ASGI scope and may be present but
        # ``None`` (e.g. unix-socket connections), so fall back explicitly.
        client = (scope.get("client") or ("", 0))[0]

        # First matching pattern wins; otherwise use the default limit.
        limit = next(
            (
                path_limit
                for pattern, path_limit in self._limits.items()
                if pattern.match(path)
            ),
            self._default_limit,
        )

        if not LIMITER.hit(limit, "path", client, path):
            await _rate_limit_exceeded_response(limit, send, client, path)
            return

        # Computed right after the hit so the headers reflect this request.
        x_headers = _get_x_headers(limit, client, path)
        x_header_names = {name.lower() for name, _ in x_headers}

        async def send_with_rate_limit_headers(message):
            if message["type"] == "http.response.start":
                # Keep every downstream header -- including repeated ones
                # such as ``set-cookie``, which a dict would collapse --
                # except those this middleware sets itself (compared
                # case-insensitively), then append the rate-limit headers.
                headers = [
                    (name, value)
                    for name, value in message.get("headers", [])
                    if name.lower() not in x_header_names
                ]
                headers.extend(x_headers)
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_with_rate_limit_headers)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment