-
-
Save dvliman/ebec917b7ef1867e27eb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Lua script executed inside Redis to make one rate-limit decision.
#
# KEYS[1] -- key holding the time at which this bucket will be full again.
# ARGV[1] -- current time in seconds, supplied by the caller.
# ARGV[2] -- amount required by this request.
# ARGV[3] -- bucket capacity (`rate`).
# ARGV[4] -- seconds needed to refill an empty bucket (`per_secs`).
# ARGV[5] -- 1 to persist the new bucket state, anything else to only check.
#
# The script returns Lua true when the request fits, false otherwise.
RATE_LIMIT_SCRIPT = r'''
local now = tonumber(ARGV[1])
local required = tonumber(ARGV[2])
local rate = tonumber(ARGV[3])
local per_secs = tonumber(ARGV[4])
local do_subtract = tonumber(ARGV[5]) == 1

-- Only the moment at which the bucket becomes full again is stored.
-- A missing or non-numeric value falls back to 0, which is always in the
-- past and therefore reads as a full bucket.
local full_at = tonumber(redis.call('GET', KEYS[1])) or 0

local score, result
if full_at < now then
    -- The refill finished in the past: the whole capacity is available.
    score = rate
else
    -- Still refilling: capacity minus what is left to refill until full_at.
    score = rate - (full_at - now) * (rate / per_secs)
end

if score < required then
    -- Rejected. The bucket is treated as empty from here on, so when
    -- do_subtract is set a rejected request restarts the full refill period.
    result = false
    score = 0
else
    result = true
    score = score - required
end

local new_full_in = per_secs * (rate - score) / rate
local new_full_at = now + new_full_in

if do_subtract then
    local ttl = math.ceil(new_full_in)
    if ttl > 0 then
        redis.call('SETEX', KEYS[1], ttl, new_full_at)
    else
        -- SETEX rejects a non-positive expiry. A TTL <= 0 means the bucket
        -- is (still) full, which is exactly what a missing key represents,
        -- so the key is removed instead of raising an error.
        redis.call('DEL', KEYS[1])
    end
end

return result
'''
class RedisRateLimiter(object):
    """Rate limiter that delegates each decision to a Lua script in Redis.

    Every identifier is mapped to its own Redis key, built as
    ``base_key + ':' + <JSON-encoded identifier>``. The script
    (RATE_LIMIT_SCRIPT) is registered with the client lazily, on the first
    call to :meth:`rate_limit`.
    """

    def __init__(self, redis, base_key, rate, per_secs):
        """Store the configuration; nothing is sent to Redis here.

        redis    -- client object; it must provide ``register_script`` and is
                    passed to the registered script as ``client``.
        base_key -- prefix of every key used by this limiter.
        rate     -- passed to the script as its third argument
                    (converted to float).
        per_secs -- passed to the script as its fourth argument
                    (converted to float).
        """
        self.script = None  # registered on first use
        self.redis = redis
        self.base_key = base_key
        self.rate = float(rate)
        self.per_secs = float(per_secs)

    def rate_limit(self, identifier, required=1.0, do_subtract=True):
        """Run the rate-limit script for ``identifier``.

        identifier  -- any JSON-serializable value; it is dumped with sorted
                       keys so equal dicts always map to the same Redis key.
        required    -- amount requested, forwarded to the script.
        do_subtract -- truthy to let the script persist the new state,
                       falsy to only check.

        Returns True when the script's reply is truthy, False otherwise.
        """
        # Imported here because the visible file never imports these names
        # at module level; a local import keeps the method self-contained.
        import json
        from time import time

        if self.script is None:
            self.script = self.redis.register_script(RATE_LIMIT_SCRIPT)

        key = self.base_key + ':' + json.dumps(identifier, sort_keys=True)
        # Floats are sent as repr() strings so the script's tonumber() gets
        # their full precision; the flag is sent as the integer 0 or 1.
        args = [
            repr(time()),
            repr(required),
            repr(self.rate),
            repr(self.per_secs),
            int(bool(do_subtract)),
        ]
        return bool(self.script(client=self.redis, keys=[key], args=args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment