Created
May 15, 2014 07:49
-
-
Save anonymous/4dcb6892346ff17709fe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Lua script (run through Redis script registration) implementing a token
# bucket for one key.
#
# KEYS[1]  key storing the time (same clock as ARGV[1]) at which the bucket
#          will be full again; a missing key is treated as a full bucket.
# ARGV[1]  current time in seconds, supplied by the caller.
# ARGV[2]  number of tokens this request costs.
# ARGV[3]  bucket capacity in tokens (used as a divisor, must be > 0).
# ARGV[4]  seconds needed to refill an empty bucket (must be > 0).
# ARGV[5]  1 to write the updated state back; any other value is a dry run.
#
# Returns Lua true (Redis integer reply 1) when enough tokens were available,
# Lua false (Redis nil reply) otherwise.
RATE_LIMIT_SCRIPT = r'''
local now = tonumber(ARGV[1])
local required = tonumber(ARGV[2])
local rate = tonumber(ARGV[3])
local per_secs = tonumber(ARGV[4])
local do_subtract = tonumber(ARGV[5]) == 1

-- GET yields false for a missing key and tonumber(false) is nil, so a
-- missing key falls back to 0, i.e. "full since long ago".
local full_at = tonumber(redis.call('GET', KEYS[1])) or 0

-- Tokens currently available: the bucket refills linearly until full_at.
local score
if full_at < now then
  score = rate
else
  score = rate - (full_at - now) * (rate / per_secs)
end

local result
if score < required then
  -- NOTE(review): a denied request drains the bucket to 0 and, when
  -- do_subtract is set, pushes full_at a whole per_secs into the future.
  -- Kept from the original script; confirm this penalty is intended.
  result = false
  score = 0
else
  result = true
  score = score - required
end

-- A negative cost (refund) must not overfill the bucket.
if score > rate then
  score = rate
end

if do_subtract then
  local new_full_in = per_secs * (rate - score) / rate
  if new_full_in > 0 then
    redis.call('SETEX', KEYS[1], math.ceil(new_full_in), now + new_full_in)
  else
    -- Bucket is full: SETEX rejects a TTL of 0, and a missing key already
    -- reads as a full bucket, so just drop any stale state.
    redis.call('DEL', KEYS[1])
  end
end

return result
'''
class RedisRateLimiter(object):
    """Token-bucket rate limiter whose per-identifier state lives in Redis.

    Each distinct identifier gets its own bucket of ``rate`` tokens that
    refills over ``per_secs`` seconds. The check-and-update is done by a
    single call to the registered ``RATE_LIMIT_SCRIPT``.

    Args:
        redis: client providing ``register_script``; the returned script
            object is called with ``client=``, ``keys=`` and ``args=``
            keyword arguments (redis-py style -- confirm for other clients).
        base_key: prefix for the per-identifier Redis keys.
        rate: bucket capacity in tokens; must be > 0.
        per_secs: seconds to refill an empty bucket; must be > 0.

    Raises:
        ValueError: if ``rate`` or ``per_secs`` is not strictly positive.
    """

    def __init__(self, redis, base_key, rate, per_secs):
        self.script = None  # registered lazily on the first rate_limit() call
        self.redis = redis
        self.base_key = base_key
        self.rate = float(rate)
        self.per_secs = float(per_secs)
        # Both values are divisors inside the Lua script; `not x > 0` also
        # rejects NaN.
        if not self.rate > 0:
            raise ValueError('rate must be positive, got %r' % (rate,))
        if not self.per_secs > 0:
            raise ValueError('per_secs must be positive, got %r' % (per_secs,))

    def rate_limit(self, identifier, required=1.0, do_subtract=True):
        """Check, and by default consume, tokens for ``identifier``.

        Args:
            identifier: JSON-serialisable value; encoded with sorted keys so
                equal dicts map to the same Redis key.
            required: number of tokens this request costs.
            do_subtract: when false, the script is asked not to write the
                updated bucket state (a dry-run check).

        Returns:
            True if the script reported enough tokens, False otherwise.
        """
        # Imported here so the method does not rely on module-level imports.
        import json
        from time import time

        if self.script is None:
            self.script = self.redis.register_script(RATE_LIMIT_SCRIPT)
        key = self.base_key + ':' + json.dumps(identifier, sort_keys=True)
        # repr() of a float keeps full precision (Python 2's str() rounds to
        # 12 significant digits). float() first so ints, Decimals, etc. all
        # serialise as plain numerals that Lua's tonumber() can parse.
        args = [
            repr(time()),
            repr(float(required)),
            repr(self.rate),
            repr(self.per_secs),
            int(bool(do_subtract)),
        ]
        return bool(self.script(client=self.redis, keys=[key], args=args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment