-
-
Save josiahcarlson/80584b49da41549a7d5c to your computer and use it in GitHub Desktop.
'''
rate_limit2.py
Copyright 2014, Josiah Carlson - [email protected]
Released under the MIT license
This module shows how to perform standard and sliding-window rate limits, as
a companion to the two articles posted on Binpress entitled
"Introduction to rate limiting with Redis", parts 1 and 2:
http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis/155
http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis-part-2/166
... which will be (or have already been) reposted on my personal blog at least 2
weeks after their original binpress.com posting:
http://www.dr-josiah.com
'''
import json | |
import time | |
from flask import g, request | |
def get_identifiers():
    '''
    Return the rate-limit identifier prefixes for the current Flask request.

    Always includes ``'ip:<remote address>'``. When ``g.user`` reports that
    it is authenticated, ``'user:<id>'`` is appended as well.
    '''
    ret = ['ip:' + request.remote_addr]
    # Flask-Login >= 0.3 exposes is_authenticated as a bool property, while
    # older versions used a method; accept both forms.
    # NOTE(review): assumes g.user is a Flask-Login style user object.
    authenticated = g.user.is_authenticated
    if callable(authenticated):
        authenticated = authenticated()
    if authenticated:
        # get_id() may return a non-string id; format instead of using +.
        ret.append('user:%s' % (g.user.get_id(),))
    return ret
def over_limit(conn, duration=3600, limit=240):
    '''
    Fixed-window rate limit check, one Redis round trip per command.

    For each identifier from get_identifiers(), increments the counter for
    the current ``duration``-second window and refreshes its expiration.
    It stops at the first identifier whose count exceeds ``limit``.

    Returns True when some identifier is over the limit, otherwise False.

    Note: a second definition of over_limit() later in this module replaces
    this one when the module is imported.
    '''
    suffix = ':%i:%i' % (duration, time.time() // duration)

    def _bump(counter):
        # INCR the window counter, then (re)set its TTL; return the new count.
        hits = conn.incr(counter)
        conn.expire(counter, duration)
        return hits

    # any() short-circuits, so later identifiers are untouched once one
    # identifier is found to be over the limit.
    return any(_bump(prefix + suffix) > limit for prefix in get_identifiers())
def over_limit_multi(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    '''
    Check several fixed windows in order, stopping at the first exceeded one.

    ``limits`` is a sequence of ``(duration, limit)`` pairs, each passed to
    over_limit(). Windows after the first exceeded one are not counted.

    Returns True if any window is over its limit, otherwise False.
    '''
    return any(over_limit(conn, duration, limit) for duration, limit in limits)
def over_limit(conn, duration=3600, limit=240, identifiers=None):
    '''
    Fixed-window rate limit check using a single pipelined round trip.

    Replaces the earlier over_limit() function and reduces round trips with
    pipelining. For every identifier, the counter for the current
    ``duration``-second window is incremented and its expiration refreshed
    inside one MULTI/EXEC transaction.

    Arguments:
        conn -- Redis connection providing ``pipeline()``.
        duration -- length of the fixed window, in seconds.
        limit -- maximum count allowed per window, per identifier.
        identifiers -- optional list of identifier prefixes; defaults to
            get_identifiers() for the current request.

    Returns True if ANY identifier is over the limit, otherwise False.
    '''
    if identifiers is None:
        identifiers = get_identifiers()
    bucket = ':%i:%i' % (duration, time.time() // duration)
    pipe = conn.pipeline(transaction=True)
    for ident in identifiers:
        key = ident + bucket
        pipe.incr(key)
        pipe.expire(key, duration)
    # execute() returns [incr_1, expire_1, incr_2, expire_2, ...]. Checking
    # only index 0 would enforce the first identifier alone (e.g. the IP)
    # and ignore the rest (e.g. the user), so check every INCR result.
    counts = pipe.execute()[::2]
    return any(count > limit for count in counts)
def over_limit_multi_lua(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    '''
    Server-side version of over_limit_multi(): all windows are checked in a
    single Lua script call.

    The first call registers ``over_limit_multi_lua_`` with
    ``conn.register_script`` and caches the result on ``conn`` as
    ``over_limit_lua``.

    Returns the script's reply (1 when over a limit, 0 otherwise).
    '''
    if not hasattr(conn, 'over_limit_lua'):
        # Register once per connection object and reuse the handle after.
        conn.over_limit_lua = conn.register_script(over_limit_multi_lua_)
    keys = get_identifiers()
    args = [json.dumps(limits), time.time()]
    return conn.over_limit_lua(keys=keys, args=args)
over_limit_multi_lua_ = '''
-- Fixed-window rate limiter applied to several windows at once.
-- KEYS: identifier prefixes (for example 'ip:...' and 'user:...').
-- ARGV[1]: JSON list of [duration, limit] pairs.
-- ARGV[2]: current time in seconds.
-- Returns 1 as soon as one counter exceeds its limit (counters already
-- incremented stay incremented), otherwise 0.
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
for i, limit in ipairs(limits) do
    local duration = limit[1]
    local bucket = ':' .. duration .. ':' .. math.floor(now / duration)
    for j, id in ipairs(KEYS) do
        local key = id .. bucket
        local count = redis.call('INCR', key)
        -- refresh the TTL so the bucket disappears after its window
        redis.call('EXPIRE', key, duration)
        if tonumber(count) > limit[2] then
            return 1
        end
    end
end
return 0
'''
def over_limit_sliding_window(conn, weight=1, limits=[(1, 10), (60, 120), (3600, 240, 60)], redis_time=False):
    '''
    Sliding-window rate limit check performed atomically in a Lua script.

    Arguments:
        conn -- Redis connection providing ``register_script`` (and ``time``
            when ``redis_time`` is true).
        weight -- how much this request counts against each limit.
        limits -- sequence of ``(duration, limit)`` or
            ``(duration, limit, precision)`` tuples, in seconds.
        redis_time -- use the Redis server clock (``conn.time()``) instead
            of the local ``time.time()``.

    Returns the script's reply: 1 when the request would exceed a limit (and
    was not counted), 0 when it was counted.
    '''
    if not hasattr(conn, 'over_limit_sliding_window_lua'):
        conn.over_limit_sliding_window_lua = conn.register_script(over_limit_sliding_window_lua_)
    if redis_time:
        # TIME replies with (seconds, microseconds). Keep the fractional part
        # so sub-second precisions bucket correctly; for integer precisions
        # the resulting block ids are unchanged.
        seconds, microseconds = conn.time()
        now = seconds + microseconds / 1000000.0
    else:
        now = time.time()
    return conn.over_limit_sliding_window_lua(
        keys=get_identifiers(), args=[json.dumps(limits), now, weight])
over_limit_sliding_window_lua_ = '''
-- Sliding-window rate limiter with optional per-limit precision.
-- KEYS: one HASH per identifier. For each limit it holds a running total,
--   per-block counts, and the id of the oldest block still tracked.
-- ARGV[1]: JSON list of [duration, limit] or [duration, limit, precision].
-- ARGV[2]: current time in seconds (may be fractional).
-- ARGV[3]: request weight (optional, defaults to 1).
-- Returns 1 if the request would exceed any limit (the request is not
-- counted), 0 once the request has been counted against every limit.
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
local weight = tonumber(ARGV[3] or '1')
-- Start from 0 (the loop below computes the real maximum) so an empty
-- limit list does not index a nil limits[1].
local longest_duration = 0
local saved_keys = {}

-- HDEL the given fields in bounded batches; unpack() of a very large table
-- can overflow the Lua C stack when duration / precision is large.
local function hdel_fields(key, fields)
    local batch = 1000
    for first = 1, #fields, batch do
        local last = math.min(first + batch - 1, #fields)
        redis.call('HDEL', key, unpack(fields, first, last))
    end
end

-- handle cleanup and limit checks
for i, limit in ipairs(limits) do
    local duration = limit[1]
    longest_duration = math.max(longest_duration, duration)
    local precision = limit[3] or duration
    precision = math.min(precision, duration)
    local blocks = math.ceil(duration / precision)
    local saved = {}
    table.insert(saved_keys, saved)
    -- block ids are measured in units of precision, not seconds
    saved.block_id = math.floor(now / precision)
    saved.trim_before = saved.block_id - blocks + 1
    saved.count_key = duration .. ':' .. precision .. ':'
    saved.ts_key = saved.count_key .. 'o'
    for j, key in ipairs(KEYS) do
        -- ts_key holds the oldest tracked block id (a block id, not seconds)
        local old_ts = redis.call('HGET', key, saved.ts_key)
        old_ts = old_ts and tonumber(old_ts) or saved.trim_before
        -- Compare block ids with block ids; comparing with now (seconds)
        -- mixes units and misfires whenever precision differs from 1.
        if old_ts > saved.block_id then
            -- don't write in the past
            return 1
        end
        -- discover what needs to be cleaned up; the last successful write
        -- stored block_id - blocks + 1, so cap the scan at blocks entries
        local decr = 0
        local dele = {}
        local trim = math.min(saved.trim_before, old_ts + blocks)
        for old_block = old_ts, trim - 1 do
            local bkey = saved.count_key .. old_block
            local bcount = redis.call('HGET', key, bkey)
            if bcount then
                decr = decr + tonumber(bcount)
                table.insert(dele, bkey)
            end
        end
        -- handle cleanup
        local cur
        if #dele > 0 then
            hdel_fields(key, dele)
            cur = redis.call('HINCRBY', key, saved.count_key, -decr)
        else
            cur = redis.call('HGET', key, saved.count_key)
        end
        -- check our limits
        if tonumber(cur or '0') + weight > limit[2] then
            return 1
        end
    end
end

-- there is enough resources, update the counts
for i, limit in ipairs(limits) do
    local saved = saved_keys[i]
    for j, key in ipairs(KEYS) do
        -- update the oldest tracked block, the total, and this block's count
        redis.call('HSET', key, saved.ts_key, saved.trim_before)
        redis.call('HINCRBY', key, saved.count_key, weight)
        redis.call('HINCRBY', key, saved.count_key .. saved.block_id, weight)
    end
end

-- We calculated the longest-duration limit so we can EXPIRE
-- the whole HASH for quick and easy idle-time cleanup :)
if longest_duration > 0 then
    for _, key in ipairs(KEYS) do
        redis.call('EXPIRE', key, longest_duration)
    end
end
return 0
'''
How would you return an actual timestamp instead of `1`, to be used in a `Retry-After` header?
The answer for you, @ciokan, is that you need to modify the Lua script to calculate the delay; right now it only returns whether you need to wait. https://gist.github.com/josiahcarlson/80584b49da41549a7d5c#file-rate_limit2-py-L157 is the line you are looking for.
Hi, I have three questions.
Question 1
In `over_limit_sliding_window_lua_`, should `if old_ts > now then` here be `if old_ts > saved.block_id then`, since `old_ts` is the oldest block id, not a timestamp?
Question 2
Should `local trim = math.min(saved.trim_before, old_ts + blocks)` here be `saved.trim_before = math.min(saved.trim_before, old_ts + blocks)`? Later, when saving the oldest block id, the code uses `saved.trim_before` (`redis.call('HSET', key, saved.ts_key, saved.trim_before)`).
Question 3
Is the purpose of `local trim = math.min(saved.trim_before, old_ts + blocks)` here to limit the number of blocks to trim to at most `blocks`?
How would you return an actual timestamp instead of `1`, to be used in a `Retry-After` header?
Replace line 157 (`return 1`) with the code below. It loops through the blocks of the current window and finds the earliest block in which a request was made. It then calculates when that block becomes stale, which is when new requests are allowed again.
-- return 1
-- Instead of a bare 1, return the UNIX time (in seconds) at which the oldest
-- block still inside this window ages out, for use in a Retry-After header.
-- Block ids are consecutive integers (floor(now / precision)), so step by 1.
-- Stepping by precision would skip blocks whenever precision > 1.
local last_attempt
for last_block = saved.trim_before, saved.block_id do
    local bcount = redis.call('HGET', key, saved.count_key .. last_block)
    if bcount then
        last_attempt = last_block
        break
    end
end
local next_attempt
if last_attempt then
    -- Block last_attempt leaves the window once block_id reaches
    -- last_attempt + blocks, i.e. at time (last_attempt + blocks) * precision.
    -- Freeing that block may still not free enough room for this request's
    -- weight, so treat the result as the earliest possible retry time.
    next_attempt = (last_attempt + blocks) * precision
else
    -- Nothing is recorded in the window, yet the request still does not fit
    -- (weight > limit), so it can never succeed. Return -1, not 0, because
    -- 0 is the script's "not over limit" reply.
    next_attempt = -1
end
return next_attempt
Note: `next_attempt` is a UNIX timestamp in seconds, not milliseconds.
@josiahcarlson, please review this code for any improvements or bugs.
It would be nice if `over_limit_sliding_window_lua` returned which limit was in effect. That is useful for taking different actions on different limits ("require a captcha for this limit, reject on that limit"). For this, you can just return `i` instead of `1`.