Last active
September 28, 2018 08:13
-
-
Save hzj629206/e27590c6d27e269e4b587ae47e97f114 to your computer and use it in GitHub Desktop.
Django Rate Limiter via Redis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals, division, print_function

import functools
import hashlib
import itertools
import logging
import math
import time

from django.http import HttpResponse
from redis.exceptions import NoScriptError, RedisError
try:
    # Older Django versions still provide get_cache() directly.
    from django.core.cache import get_cache
except ImportError:  # Django 1.9+
    from django.core.cache import caches

    def get_cache(backend):
        """Return the cache configured under ``CACHES[backend]``.

        Drop-in replacement for the ``get_cache`` helper that newer Django
        versions no longer provide.
        """
        cache = caches[backend]
        return cache
LUA_SCRIPT_SMOOTH = '''
--[[
Token-bucket rate limiter executed atomically inside Redis ("smooth" variant).

KEYS[i] is the hash that stores bucket i. Bucket i is described by the six
ARGV entries starting at ARGV[(i - 1) * 6 + 1]:
  1. interval          - bucket window in milliseconds
  2. capacity          - maximum number of tokens the bucket can hold
  3. nTokens           - number of tokens this call wants to consume
  4. timeNow           - current timestamp in milliseconds, supplied by the caller
  5. expire            - TTL in seconds for the bucket key (0 or less: no TTL)
  6. intervalPerPermit - milliseconds needed to refill one token

Hash fields: lastRefillTime (ms) and tokensRemaining.

Tokens come back one per intervalPerPermit ms, up to capacity. A bucket that
is new, or idle for longer than interval, starts with only the tokens this
request needs (never more than capacity), so there is no initial burst.

All-or-nothing: if every bucket has enough tokens they are all debited and
an empty-key reply is returned; otherwise nothing is debited and the first
bucket that came up short is returned as
{key, interval, capacity, currentTokens, lastRefillTime}.
]] --
local effects = {}
for idx, key in ipairs(KEYS) do
  local idxBase = (idx - 1) * 6
  local interval = tonumber(ARGV[idxBase + 1])
  local capacity = tonumber(ARGV[idxBase + 2])
  local nTokens = tonumber(ARGV[idxBase + 3])
  local timeNow = tonumber(ARGV[idxBase + 4])
  local expire = tonumber(ARGV[idxBase + 5])
  local intervalPerPermit = tonumber(ARGV[idxBase + 6])

  -- Tokens granted to a new or reset bucket: enough for this request, but
  -- never more than the bucket can hold (otherwise n > capacity would pass).
  local burstTokens = math.min(nTokens, capacity)

  -- Read the fields by name rather than relying on HGETALL ordering.
  -- Missing fields come back as false, which tonumber() turns into nil.
  local stored = redis.call('hmget', key, 'lastRefillTime', 'tokensRemaining')
  local storedRefillTime = tonumber(stored[1])
  local tokensRemaining = tonumber(stored[2])

  local currentTokens
  local lastRefillTime = timeNow
  if storedRefillTime == nil or tokensRemaining == nil then
    -- New (or incompletely written) bucket: start it now.
    currentTokens = burstTokens
    redis.call('hset', key, 'lastRefillTime', timeNow)
  elseif timeNow > storedRefillTime then
    local intervalSinceLast = timeNow - storedRefillTime
    if intervalSinceLast > interval then
      -- Idle for a whole window: restart the bucket.
      currentTokens = burstTokens
      redis.call('hset', key, 'lastRefillTime', timeNow)
    else
      lastRefillTime = storedRefillTime
      local grantedTokens = math.floor(intervalSinceLast / intervalPerPermit)
      if grantedTokens > 0 then
        -- Advance the refill time by whole permits only, so the leftover
        -- fraction of a permit counts towards the next token.
        local padMillis = math.fmod(intervalSinceLast, intervalPerPermit)
        lastRefillTime = timeNow - padMillis
        redis.call('hset', key, 'lastRefillTime', lastRefillTime)
      end
      currentTokens = math.min(grantedTokens + tokensRemaining, capacity)
    end
  else
    -- A call with a later clock already refilled this bucket; do not refill again.
    lastRefillTime = storedRefillTime
    currentTokens = math.min(tokensRemaining, capacity)
  end

  if expire > 0 then
    -- EXPIRE only accepts whole seconds.
    redis.call('expire', key, math.ceil(expire))
  end

  if nTokens > currentTokens then
    -- Denied: store this bucket's refilled count and put back the refilled,
    -- undebited counts of the buckets checked before it.
    redis.call('hset', key, 'tokensRemaining', currentTokens)
    for _, effect in ipairs(effects) do
      redis.call('hset', effect[1], 'tokensRemaining', effect[2])
    end
    return {key, interval, capacity, currentTokens, lastRefillTime}
  end
  table.insert(effects, {key, currentTokens, nTokens})
end

-- Every bucket had enough tokens: debit them all.
for _, effect in ipairs(effects) do
  redis.call('hset', effect[1], 'tokensRemaining', effect[2] - effect[3])
end
return {'', 0, 0, 0, 0}
'''
LUA_SCRIPT = '''
--[[
Fixed-window token-bucket rate limiter executed atomically inside Redis.

KEYS[i] is the hash that stores bucket i. Bucket i is described by the five
ARGV entries starting at ARGV[(i - 1) * 5 + 1]:
  1. interval - window length in milliseconds
  2. capacity - tokens available per window
  3. nTokens  - number of tokens this call wants to consume
  4. timeNow  - current timestamp in milliseconds, supplied by the caller
  5. expire   - TTL in seconds for the bucket key (0 or less: no TTL)

Hash fields: lastRefillTime (ms, start of the current window) and tokensRemaining.

A bucket is refilled to full capacity when it is new or when more than
interval ms have passed since lastRefillTime.

All-or-nothing: if every bucket has enough tokens they are all debited and
an empty-key reply is returned; otherwise nothing is debited and the first
bucket that came up short is returned as
{key, interval, capacity, currentTokens, lastRefillTime}.
]] --
local effects = {}
for idx, key in ipairs(KEYS) do
  local idxBase = (idx - 1) * 5
  local interval = tonumber(ARGV[idxBase + 1])
  local capacity = tonumber(ARGV[idxBase + 2])
  local nTokens = tonumber(ARGV[idxBase + 3])
  local timeNow = tonumber(ARGV[idxBase + 4])
  local expire = tonumber(ARGV[idxBase + 5])

  -- Missing fields come back as false, which tonumber() turns into nil.
  local stored = redis.call('hmget', key, 'lastRefillTime', 'tokensRemaining')
  local lastRefillTime = tonumber(stored[1])
  local tokensRemaining = tonumber(stored[2])

  local currentTokens
  if lastRefillTime == nil or timeNow - lastRefillTime > interval then
    -- New bucket, or the window has elapsed: open a new window at full capacity.
    currentTokens = capacity
    lastRefillTime = timeNow
    redis.call('hset', key, 'lastRefillTime', timeNow)
  else
    -- Same window: use what is left (a missing count is treated as a full
    -- bucket), capped in case capacity was lowered.
    currentTokens = math.min(tokensRemaining or capacity, capacity)
  end

  if expire > 0 then
    -- EXPIRE only accepts whole seconds.
    redis.call('expire', key, math.ceil(expire))
  end

  if nTokens > currentTokens then
    -- Denied: store this bucket's count under the same field that is read
    -- back above, and restore the undebited counts of the earlier buckets.
    redis.call('hset', key, 'tokensRemaining', currentTokens)
    for _, effect in ipairs(effects) do
      redis.call('hset', effect[1], 'tokensRemaining', effect[2])
    end
    return {key, interval, capacity, currentTokens, lastRefillTime}
  end
  table.insert(effects, {key, currentTokens, nTokens})
end

-- Every bucket had enough tokens: debit them all.
for _, effect in ipairs(effects) do
  redis.call('hset', effect[1], 'tokensRemaining', effect[2] - effect[3])
end
return {'', 0, 0, 0, 0}
'''
# SHA1 digests used to invoke the scripts with EVALSHA. The scripts are text
# (``unicode_literals``) and hashlib only accepts bytes on Python 3, so encode
# them the same way redis-py does when it sends the script body (UTF-8).
LUA_SCRIPT_SHA1 = hashlib.sha1(LUA_SCRIPT.encode('utf-8')).hexdigest()
LUA_SCRIPT_SMOOTH_SHA1 = hashlib.sha1(LUA_SCRIPT_SMOOTH.encode('utf-8')).hexdigest()
class RedisConsumeDenied(object):
    """Describes the bucket that rejected a consume attempt.

    Built from the reply the rate-limiter Lua scripts return on denial:
    ``[redis_key, interval_ms, capacity, current_tokens, last_refill_ms]``.
    The interval is exposed in seconds.
    """

    def __init__(self, redis_rv):
        (
            self.redis_key,
            self.interval,
            self.capacity,
            self.current_tokens,
            self.last_fill_at,
        ) = (
            redis_rv[0],
            redis_rv[1] / 1000,  # milliseconds -> seconds
            redis_rv[2],
            redis_rv[3],
            redis_rv[4],
        )

    def __repr__(self):
        template = '<RedisConsumeDenied([{}] interval={}, capacity={}, tokens={})>'
        return template.format(
            self.redis_key, self.interval, self.capacity, self.current_tokens,
        )
class RateLimiter(object):
    """
    Token-bucket rate limiter whose buckets live in Redis hashes.

    Each bucket is stored under ``<key_prefix>:<key>:<interval>``. All buckets
    passed to one ``consume_*`` call are checked and debited by a single Lua
    script invocation: either every bucket is charged or none is.

    Redis failures fail open: when the script cannot be run, the request is
    allowed and a warning is logged.
    """

    redis_cache = get_cache('redis_cache')  # CACHES['redis_cache'] in settings.py

    def __init__(self, key_prefix=None):
        """
        :param str key_prefix: prefix of every bucket key (default ``b'rate_limiter'``);
            also passed as the key to the cache backend's ``get_client``.
        """
        self.key_prefix = key_prefix or b'rate_limiter'
        self.redis_cli = self.redis_cache.get_client(self.key_prefix, write=True)

    @staticmethod
    def _to_text(value):
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def make_key(self, key, interval):
        """
        Redis key of the bucket for ``key`` and ``interval``.

        The key is built as text: ``bytes`` has no ``format`` method on
        Python 3, and on Python 2 a byte-string template fails for non-ASCII
        ``unicode`` keys such as request paths.

        :param str key:
        :param float|int interval:
        :rtype: str
        """
        return '{}:{}:{}'.format(
            self._to_text(self.key_prefix), self._to_text(key), interval
        )

    def now_ms(self):
        """Current time as integer milliseconds since the epoch."""
        return int(time.time() * 1000)

    @staticmethod
    def _expire_seconds(interval):
        """
        TTL for a bucket key: two windows plus a 15 s margin. The value is rounded
        up to whole seconds because Redis ``EXPIRE`` rejects fractional values.
        """
        return int(math.ceil(interval * 2)) + 15

    def _run_script(self, script, script_sha1, script_keys, script_args):
        """
        Run one of the rate-limiter Lua scripts and interpret its reply.

        The script is invoked with EVALSHA. When Redis does not know the
        script yet (NOSCRIPT), it is loaded and the call retried, for up to three
        attempts. Any other Redis error fails open.

        :rtype: (bool, RedisConsumeDenied)
        """
        for _ in range(3):
            try:
                try:
                    rv = self.redis_cli.evalsha(
                        script_sha1, len(script_keys), *(script_keys + script_args)
                    )
                except NoScriptError:
                    self.redis_cli.script_load(script)
                    continue
            except RedisError:
                logging.getLogger(__name__).warning(
                    'rate limiter script failed; allowing request', exc_info=True
                )
                return True, None
            # The scripts reply with an empty key when every bucket had enough
            # tokens. Test truthiness: redis-py returns b'' (not '') on Python 3.
            if not rv[0]:
                return True, None
            return False, RedisConsumeDenied(rv)
        return True, None

    def consume_smooth(self, args):
        """
        Consume tokens from several buckets at once. Each bucket refills
        gradually, one token every ``interval / capacity`` seconds.

        :param list[(str, float|int, int, int)] args: ``(key, interval, capacity, n)``
            per bucket; ``interval`` in seconds, ``n`` tokens to consume.
        :rtype: (bool, RedisConsumeDenied)
        :return: ``(True, None)`` when allowed, otherwise ``(False, denial)``
            describing the first bucket that lacked tokens.
        """
        if not args:
            return True, None
        script_keys = []
        script_args = []
        the_now_ms = self.now_ms()
        for (key, interval, capacity, n) in args:
            interval_ms = interval * 1000
            script_keys.append(self.make_key(key, interval))
            script_args.extend([
                interval_ms,
                capacity,
                n,
                the_now_ms,
                self._expire_seconds(interval),
                interval_ms / capacity,  # milliseconds needed to refill one token
            ])
        return self._run_script(
            LUA_SCRIPT_SMOOTH, LUA_SCRIPT_SMOOTH_SHA1, script_keys, script_args
        )

    def consume_multi(self, args):
        """
        Consume tokens from several buckets at once, using fixed windows. A
        bucket refills to full capacity once more than ``interval`` has passed
        since its last refill.

        :param list[(str, float|int, int, int)] args: ``(key, interval, capacity, n)``
            per bucket; ``interval`` in seconds, ``n`` tokens to consume.
        :rtype: (bool, RedisConsumeDenied)
        :return: ``(True, None)`` when allowed, otherwise ``(False, denial)``
            describing the first bucket that lacked tokens.
        """
        if not args:
            return True, None
        script_keys = []
        script_args = []
        the_now_ms = self.now_ms()
        for (key, interval, capacity, n) in args:
            script_keys.append(self.make_key(key, interval))
            script_args.extend([
                interval * 1000,
                capacity,
                n,
                the_now_ms,
                self._expire_seconds(interval),
            ])
        return self._run_script(LUA_SCRIPT, LUA_SCRIPT_SHA1, script_keys, script_args)

    def consume(self, key, interval, capacity, n=1, smooth=True):
        """
        Consume ``n`` tokens from a single bucket.

        :param str key:
        :param float|int interval: window length in seconds
        :param int capacity:
        :param int n:
        :param bool smooth: use ``consume_smooth`` (True) or ``consume_multi`` (False)
        :rtype: (bool, RedisConsumeDenied)
        """
        if smooth:
            return self.consume_smooth([(key, interval, capacity, n)])
        else:
            return self.consume_multi([(key, interval, capacity, n)])

    def dump(self, key, interval):
        """
        Print the raw Redis hash of a bucket (debugging aid).

        :param str key:
        :param float|int interval:
        """
        print(self.redis_cli.hgetall(self.make_key(key, interval)))
class RatePolicy(object):
    """
    Global rate limit: a single bucket per request path, shared by every caller.

    Override ``make_key`` to narrow the bucket (for example per user), or
    ``groups`` to charge several buckets for one request.
    """

    def __init__(self, interval, capacity):
        """
        :param float|int interval: window length in seconds
        :param int capacity: tokens available per window
        """
        self.interval, self.capacity = interval, capacity

    def make_key(self, request):
        """
        Bucket name for ``request``: its path.

        :param request: incoming request; only ``request.path`` is read
        :rtype: str
        """
        path = request.path
        return path

    def groups(self, request):
        """
        Buckets to charge for ``request``.

        :param request:
        :rtype: list[(str, float|int, int, int)]
        :return: ``(key, interval, capacity, tokens)`` tuples; here a single
            bucket charged one token.
        """
        bucket = (self.make_key(request), self.interval, self.capacity, 1)
        return [bucket]
# Default limiter used by ``rate_limit`` when no ``limiter`` argument is given.
# It is created at import time, which calls the cache backend's get_client().
_limiter = RateLimiter(key_prefix=None)
def rate_limit(policies, smooth=True, limiter=None):
    """
    Decorator that applies rate-limit policies to a Django view function.

    Usage: @rate_limit([RatePolicy(1, 1)])

    The buckets of every policy (``policy.groups(request)``) are charged in a
    single limiter call. The view runs when the limiter allows the request or
    when the policies yield no buckets at all. Otherwise the view is not
    called and an empty HTTP 429 response is returned.

    :param list[RatePolicy] policies:
    :param bool smooth: charge with ``consume_smooth`` (True) or ``consume_multi`` (False)
    :param RateLimiter limiter: limiter to use; falls back to the module-level ``_limiter``
    """
    def _decorator(func):
        @functools.wraps(func)
        def _func(request, *args, **kwargs):
            limit_args = []
            for policy in policies:
                limit_args.extend(policy.groups(request))
            active_limiter = limiter or _limiter
            if limit_args:
                if smooth:
                    charge = active_limiter.consume_smooth
                else:
                    charge = active_limiter.consume_multi
                allowed, _denial = charge(limit_args)
                if not allowed:
                    return HttpResponse(status=429)
            return func(request, *args, **kwargs)
        return _func
    return _decorator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment