@ptarjan
Created March 29, 2017 20:44

Scaling your API with rate limiters

The following are examples of the four types of rate limiters discussed in the accompanying blog post. In the examples below I've used pseudocode-like Ruby, so even if you're unfamiliar with Ruby you should be able to translate this approach to other languages easily. Complete examples in Ruby are also provided later in this gist.

In most cases you'll want all these examples to be classes, but I've used simple functions here to keep the code samples brief.

Request rate limiter

This uses a basic token bucket algorithm and relies on the fact that Redis scripts execute atomically. No other operations can run between fetching the count and writing the new count.

The full script with a small test suite is available, but here is a sketch:

# How many requests per second do you want a user to be allowed to do?
REPLENISH_RATE = 100

# How much bursting do you want to allow?
CAPACITY = 5 * REPLENISH_RATE

SCRIPT = File.read('request_rate_limiter.lua')

def check_request_rate_limiter(user)
  # Make a unique key per user.
  prefix = 'request_rate_limiter.' + user

  # You need two Redis keys for Token Bucket.
  keys = [prefix + '.tokens', prefix + '.timestamp']

  # The arguments to the Lua script. Time.new.to_i returns Unix time in seconds.
  args = [REPLENISH_RATE, CAPACITY, Time.new.to_i, 1]

  begin
    allowed, tokens_left = redis.eval(SCRIPT, keys, args)
  rescue RedisError => e
    # Fail open. We don't want a hard dependency on Redis to allow traffic.
    # Make sure to set an alert so you know if this is happening too much.
    # Our observed failure rate is 0.01%.
    puts 'Redis failed: ' + e.message
    return
  end

  if !allowed
    raise RateLimitError.new(status_code: 429)
  end
end

Here is the corresponding request_rate_limiter.lua script:

local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
  new_tokens = filled_tokens - requested
end

redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)

return { allowed, new_tokens }
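
To make the refill arithmetic easier to follow, here is a minimal in-memory sketch of the same steps in Ruby. It is an illustration rather than part of the limiter: `take_tokens` and the `store` hash are hypothetical stand-ins for the script and for Redis, and key expiry is not modelled. It also suggests why the TTL looks safe: every call resets the TTL to twice the fill time, so the keys can only expire after the bucket would have refilled completely, and a missing key is treated as a full bucket anyway.

```ruby
# In-memory stand-in for the Lua script; `store` plays the role of Redis.
# Hypothetical helper for illustration only; the TTL is not modelled.
def take_tokens(store, rate, capacity, now, requested)
  last_tokens = store.fetch(:tokens, capacity)   # missing key => full bucket
  last_refreshed = store.fetch(:timestamp, 0)    # missing key => time 0

  # Refill in proportion to the elapsed time, capped at the capacity.
  delta = [0, now - last_refreshed].max
  filled_tokens = [capacity, last_tokens + delta * rate].min
  allowed = filled_tokens >= requested
  new_tokens = allowed ? filled_tokens - requested : filled_tokens

  store[:tokens] = new_tokens
  store[:timestamp] = now
  [allowed, new_tokens]
end

store = {}
# A burst of 500 single-token requests in the same second drains the bucket.
500.times { take_tokens(store, 100, 500, 1_000, 1) }
p take_tokens(store, 100, 500, 1_000, 1)   # => [false, 0]
# One second later, 100 tokens have been replenished.
p take_tokens(store, 100, 500, 1_001, 1)   # => [true, 99]
# After fill_time (500 / 100 = 5 seconds) of silence the bucket is full again,
# which is the same state a missing (expired) key produces.
p take_tokens(store, 100, 500, 1_006, 1)   # => [true, 499]
p take_tokens({}, 100, 500, 1_006, 1)      # => [true, 499]
```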

Concurrent requests limiter

Because Redis is so fast, doing the naive thing works. Just add a random token to a set at the start of a request and remove it from the set when you're done. If the set is too large, reject the request.

Again the full code is available and a sketch follows:

# The maximum length a request can take
TTL = 60

# How many concurrent requests a user can have going at a time
CAPACITY = 100

SCRIPT = File.read('concurrent_requests_limiter.lua')

class ConcurrentRequestLimiter
  def check(user)
    @timestamp = Time.new.to_i

    # A string of some random characters. Make it long enough to make sure two machines don't have the same string in the same TTL.
    id = Random.new.bytes(4)
    key = 'concurrent_requests_limiter.' + user
    begin
      # Clear out old requests that probably got lost
      redis.zremrangebyscore(key, '-inf', @timestamp - TTL)
      keys = [key]
      args = [CAPACITY, @timestamp, id]
      allowed, count = redis.eval(SCRIPT, keys, args)
    rescue RedisError => e
      # Similarly to above, remember to fail open so Redis outages don't take down your site
      log.info('Redis failed: ' + e.message)
      return
    end

    if allowed
      # Save it for later so we can remove it when the request is done
      @id_in_redis = id
    else
      raise RateLimitError.new(status_code: 429)
    end
  end

  # Call this method after a request finishes
  def post_request_bookkeeping(user)
    if not @id_in_redis
      return
    end
    key = 'concurrent_requests_limiter.' + user
    removed = redis.zrem(key, @id_in_redis)
  end

  def do_request(user)
    check(user)
    begin
      # Do the actual work here
    ensure
      # Always release the slot, even if the work raises
      post_request_bookkeeping(user)
    end
  end
end

The content of concurrent_requests_limiter.lua is simple and is meant to guarantee the atomicity of the ZCARD and ZADD.

local key = KEYS[1]

local capacity = tonumber(ARGV[1])
local timestamp = tonumber(ARGV[2])
local id = ARGV[3]

local count = redis.call("zcard", key)
local allowed = count < capacity

if allowed then
  redis.call("zadd", key, timestamp, id)
end

return { allowed, count }
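
The life cycle of an entry in the sorted set can be sketched without Redis. The following is a hedged in-memory model, not the gist's implementation: `InMemoryConcurrencyLimiter` and its methods are hypothetical names, a Hash of id to start time stands in for the sorted set, and there is no atomicity concern because everything runs in one process.

```ruby
require 'securerandom'

# Hypothetical in-memory model of the sorted set: a Hash of id => start time.
class InMemoryConcurrencyLimiter
  def initialize(capacity:, ttl:)
    @capacity = capacity
    @ttl = ttl
    @in_flight = {}
  end

  # Returns a request id if admitted, or nil if the limit is reached.
  def acquire(now)
    # Counterpart of ZREMRANGEBYSCORE key -inf (now - ttl)
    @in_flight.delete_if { |_id, started_at| started_at <= now - @ttl }
    # Counterpart of the ZCARD check followed by ZADD
    return nil if @in_flight.size >= @capacity
    id = SecureRandom.hex(8)
    @in_flight[id] = now
    id
  end

  # Counterpart of ZREM when the request finishes
  def release(id)
    @in_flight.delete(id)
  end

  def in_flight_count
    @in_flight.size
  end
end

limiter = InMemoryConcurrencyLimiter.new(capacity: 2, ttl: 60)
a = limiter.acquire(0)
limiter.acquire(1)
p limiter.acquire(2).nil?    # => true, two requests are already in flight
limiter.release(a)
p limiter.acquire(3).nil?    # => false, a slot was freed
p limiter.acquire(61).nil?   # => false, the request started at 1 has expired
```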

Fleet usage load shedder

We can now move from preventing abuse to adding stability to your site with load shedders. If you can categorize traffic into buckets where no fewer than X% of your workers should be available to process high-priority traffic, then you're in luck: this type of algorithm can help. We call it a load shedder instead of a rate limiter because it isn't trying to reduce the rate of a specific user's requests. Instead, it adds backpressure so internal systems can recover.

When this load shedder kicks in it will start dropping non-critical traffic. There should be alarm bells ringing and people should be working to get the traffic back, but at least your core traffic will work. For Stripe, high-priority traffic has to do with creating charges and moving money around, and low-priority traffic has to do with analytics and reporting.

The great thing about this load shedder is that its implementation is identical to the Concurrent Requests Limiter, except that it uses a single global key instead of a user-specific key.

# A constant, because a top-level local variable is not visible inside a method
FLEET_LIMITER = ConcurrentRequestLimiter.new

def check_fleet_usage_load_shedder
  if is_high_priority_request
    return
  end

  begin
    return FLEET_LIMITER.do_request('fleet_usage_load_shedder')
  rescue RateLimitError
    raise RateLimitError.new(status_code: 503)
  end
end
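
One way to pick the capacity for the global key, assuming you know the fleet's total worker count, is to derive it from the share of workers you want to keep free for high-priority traffic. The helper below is a hypothetical sketch; neither its name nor the numbers come from the original post.

```ruby
# Hypothetical helper: how many low-priority requests may be in flight
# fleet-wide if `reserved_percent` of workers must stay free for
# high-priority traffic. Integer arithmetic rounds the result down.
def low_priority_capacity(total_workers, reserved_percent)
  total_workers * (100 - reserved_percent) / 100
end

p low_priority_capacity(1_000, 20)   # => 800
```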

Worker utilization load shedder

This load shedder is the last resort, and only kicks in when a machine is under heavy pressure and needs to offload. The code for determining how many workers are in use is dependent on your infrastructure. The general outline is to figure out some measure of "Is our infrastructure currently failing?" If that function returns something non-zero, start throwing out your least important requests (after waiting a short period to allow for imprecise measurements) with higher and higher probability. After a period of time doing that, move on to more requests until you are throwing out everything except for the most critical traffic.

The most important behavior for this load shedder is to slowly take action. Don't start throwing out traffic until your infrastructure has been sad for quite a while (30 seconds), and don't instantaneously add traffic back. Sharp changes in shedding amounts will cause wild swings and lead to failure modes that are hard to diagnose.

As before, the full script with a small test suite is available, and here is a sketch:

END_OF_GOOD_UTILIZATION = 0.7
START_OF_BAD_UTILIZATION = 0.8

# Assuming a sample rate of 8 seconds, so 28 == 3.5 * 8 == guaranteed 3 samples
NUMBER_OF_SECONDS_BEFORE_SHEDDING_STARTS = 28
NUMBER_OF_SECONDS_TO_SHED_ALL_TRAFFIC = 120

# Use float division; with integers, -28 / 120 evaluates to -1 in Ruby
RESTING_SHED_AMOUNT = -NUMBER_OF_SECONDS_BEFORE_SHEDDING_STARTS.to_f / NUMBER_OF_SECONDS_TO_SHED_ALL_TRAFFIC

@shedding_amount_last_changed = 0
@shedding_amount = 0

def check_worker_utilization_load_shedder
  chance = drop_chance(current_worker_utilization)
  if chance == 0
    dropped = false
  else
    dropped = Random.rand() < chance
  end
  if dropped
    raise RateLimitError.new(status_code: 503)
  end
end

def drop_chance(utilization)
  update_shedding_amount_derivative(utilization)
  how_much_traffic_to_shed
end

def update_shedding_amount_derivative(utilization)
  # A number from -1 to 1
  amount = 0

  # Linearly reduce shedding
  if utilization < END_OF_GOOD_UTILIZATION
    amount = utilization / END_OF_GOOD_UTILIZATION - 1
  # A dead zone
  elsif utilization < START_OF_BAD_UTILIZATION
    amount = 0
  # Shed traffic
  else
    amount = (utilization - START_OF_BAD_UTILIZATION) / (1 - START_OF_BAD_UTILIZATION)
  end

  # scale the derivative so we take time to shed all the traffic
  @shedding_amount_derivative = clamp(amount, -1, 1) / NUMBER_OF_SECONDS_TO_SHED_ALL_TRAFFIC
end

def how_much_traffic_to_shed
  now = Time.now().to_f
  seconds_since_last_math = clamp(now - @shedding_amount_last_changed, 0, NUMBER_OF_SECONDS_BEFORE_SHEDDING_STARTS)
  @shedding_amount_last_changed = now
  @shedding_amount += seconds_since_last_math * @shedding_amount_derivative
  @shedding_amount = clamp(@shedding_amount, RESTING_SHED_AMOUNT, 1)
end

def current_worker_utilization
  # Returns a double from 0 to 1.
  # 1 means every process is busy, .5 means 1/2 the processes are working, and 0 means the machine is servicing 0 requests
  # This is infra dependent on how to read this value
end

def clamp(val, min, max)
  if val < min
    return min
  elsif val > max
    return max
  else
    return val
  end
end
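
The timing behaviour is easier to see with a simulated clock. The sketch below re-implements the same arithmetic as a deterministic simulation with one-second steps; the shortened constant names and the `derivative` and `simulate` helpers are illustrative assumptions, not part of the original script.

```ruby
END_OF_GOOD = 0.7
START_OF_BAD = 0.8
SECONDS_BEFORE_SHEDDING = 28
SECONDS_TO_SHED_ALL = 120
RESTING = -SECONDS_BEFORE_SHEDDING.to_f / SECONDS_TO_SHED_ALL

# Rate of change of the shedding amount, per second, at a given utilization.
def derivative(utilization)
  amount =
    if utilization < END_OF_GOOD
      utilization / END_OF_GOOD - 1
    elsif utilization < START_OF_BAD
      0.0
    else
      (utilization - START_OF_BAD) / (1 - START_OF_BAD)
    end
  amount.clamp(-1.0, 1.0) / SECONDS_TO_SHED_ALL
end

# Advance the shedding amount in one-second steps at constant utilization.
def simulate(start_amount, utilization, seconds)
  amount = start_amount
  seconds.times do
    amount = (amount + derivative(utilization)).clamp(RESTING, 1.0)
  end
  amount
end

# At full utilization, the amount takes 28 seconds to climb from rest to 0...
p simulate(RESTING, 1.0, 28).abs < 1e-9    # => true
# ...and another 120 seconds to reach 1, where all traffic is shed.
p simulate(RESTING, 1.0, 148).round(3)     # => 1.0
# An idle machine recovers at the same pace and settles at the resting amount.
p simulate(1.0, 0.0, 148).round(3)         # => -0.233
# In the dead zone the amount does not change.
p simulate(0.5, 0.75, 60)                  # => 0.5
```

Requests are only dropped while the amount is positive, so the negative resting value is what provides the delay before shedding starts.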

Complete request rate limiter example, with a small test:

require_relative 'shared'

# How many requests per second do you want a user to be allowed to do?
REPLENISH_RATE = 100

# How much bursting do you want to allow?
CAPACITY = 5 * REPLENISH_RATE

SCRIPT = File.read('request_rate_limiter.lua')

def check_request_rate_limiter(user)
  # Make a unique key per user.
  prefix = 'request_rate_limiter.' + user

  # You need two Redis keys for Token Bucket.
  keys = [prefix + '.tokens', prefix + '.timestamp']

  # The arguments to the Lua script. Time.new.to_i returns Unix time in seconds.
  args = [REPLENISH_RATE, CAPACITY, Time.new.to_i, 1]

  begin
    allowed, tokens_left = redis.eval(SCRIPT, keys, args)
  rescue Redis::BaseError => e
    # Fail open. We don't want a hard dependency on Redis to allow traffic.
    # Make sure to set an alert so you know if this is happening too much.
    # Our failure rate is 0.01%.
    puts 'Redis failed: ' + e.message
    return
  end

  if !allowed
    raise RateLimitError.new(status_code: 429)
  end
end

def test_check_request_rate_limiter
  id = Random.rand(1000000).to_s

  # Bursts work
  for i in 0..CAPACITY-1
    check_request_rate_limiter(id)
  end

  begin
    check_request_rate_limiter(id)
    raise "it didn't throw :("
  rescue RateLimitError
    puts "it correctly threw"
  end

  sleep 1

  # After the burst is done, check the steady state
  for i in 0..REPLENISH_RATE-1
    check_request_rate_limiter(id)
  end

  begin
    check_request_rate_limiter(id)
    raise "it didn't throw :("
  rescue RateLimitError
    puts "it correctly threw"
  end
end

test_check_request_rate_limiter

Complete concurrent requests limiter example, with a small test:

require_relative 'shared'

# The maximum length a request can take
TTL = 60

# How many concurrent requests a user can have going at a time
CAPACITY = 100

SCRIPT = File.read('concurrent_requests_limiter.lua')

def check_concurrent_requests_limiter(user)
  @timestamp = Time.new().to_i

  # A string of some random characters. Make it long enough to make sure two machines don't have the same string in the same TTL.
  id = Random.new.bytes(4)
  key = 'concurrent_requests_limiter.' + user
  begin
    # Clear out old requests that probably got lost
    redis.zremrangebyscore(key, '-inf', @timestamp - TTL)
    keys = [key]
    args = [CAPACITY, @timestamp, id]
    allowed, count = redis.eval(SCRIPT, keys, args)
  rescue Redis::BaseError => e
    # Similarly to above, remember to fail open so Redis outages don't take down your site
    puts 'Redis failed: ' + e.message
    return
  end

  if allowed
    # Save it for later so we can remove it when the request is done
    @id_in_redis = id
  else
    raise RateLimitError.new(status_code: 429)
  end
end

# Call this method after a request finishes
def post_request_bookkeeping(user)
  if not @id_in_redis
    return
  end
  key = 'concurrent_requests_limiter.' + user
  removed = redis.zrem(key, @id_in_redis)
end

def do_request(user)
  check_concurrent_requests_limiter(user)

  # Do the actual work here

  post_request_bookkeeping(user)
end

def test_check_concurrent_requests_limiter
  id = Random.rand(1000000).to_s

  # Pounding the server is fine as long as you finish the request
  for i in 0..CAPACITY*10
    do_request(id)
  end

  # But concurrent is not
  for i in 0..CAPACITY-1
    check_concurrent_requests_limiter(id)
  end

  begin
    check_concurrent_requests_limiter(id)
    raise "it didn't work"
  # Rescue only RateLimitError; a bare rescue would also swallow the raise above
  rescue RateLimitError
    puts "it worked"
  end
end

test_check_concurrent_requests_limiter

Complete worker utilization load shedder example, with a small test:

require_relative 'shared'

END_OF_GOOD_UTILIZATION = 0.7
START_OF_BAD_UTILIZATION = 0.8

# Assuming a sample rate of 8 seconds, so 28 == 3.5 * 8 == guaranteed 3 samples
NUMBER_OF_SECONDS_BEFORE_SHEDDING_STARTS = 28
NUMBER_OF_SECONDS_TO_SHED_ALL_TRAFFIC = 120

# Use float division; with integers, -28 / 120 evaluates to -1 in Ruby
RESTING_SHED_AMOUNT = -NUMBER_OF_SECONDS_BEFORE_SHEDDING_STARTS.to_f / NUMBER_OF_SECONDS_TO_SHED_ALL_TRAFFIC

@shedding_amount_last_changed = 0
@shedding_amount = 0

def check_worker_utilization_load_shedder
  chance = drop_chance(current_worker_utilization)
  if chance == 0
    dropped = false
  else
    dropped = Random.rand() < chance
  end
  if dropped
    raise RateLimitError.new(status_code: 503)
  end
end

def drop_chance(utilization)
  update_shedding_amount_derivative(utilization)
  how_much_traffic_to_shed
end

def update_shedding_amount_derivative(utilization)
  # A number from -1 to 1
  amount = 0

  # Linearly reduce shedding
  if utilization < END_OF_GOOD_UTILIZATION
    amount = utilization / END_OF_GOOD_UTILIZATION - 1
  # A dead zone
  elsif utilization < START_OF_BAD_UTILIZATION
    amount = 0
  # Shed traffic
  else
    amount = (utilization - START_OF_BAD_UTILIZATION) / (1 - START_OF_BAD_UTILIZATION)
  end

  # scale the derivative so we take time to shed all the traffic
  @shedding_amount_derivative = clamp(amount, -1, 1) / NUMBER_OF_SECONDS_TO_SHED_ALL_TRAFFIC
end

def how_much_traffic_to_shed
  now = Time.now().to_f
  seconds_since_last_math = clamp(now - @shedding_amount_last_changed, 0, NUMBER_OF_SECONDS_BEFORE_SHEDDING_STARTS)
  @shedding_amount_last_changed = now
  @shedding_amount += seconds_since_last_math * @shedding_amount_derivative
  @shedding_amount = clamp(@shedding_amount, RESTING_SHED_AMOUNT, 1)
end

def current_worker_utilization
  # Returns a double from 0 to 1.
  # 1 means every process is busy, .5 means 1/2 the processes are working, and 0 means the machine is servicing 0 requests
  @current_worker_utilization # For easy stubbing in the test example
end

def clamp(val, min, max)
  if val < min
    return min
  elsif val > max
    return max
  else
    return val
  end
end

def test_check_worker_utilization_load_shedder
  # Business as usual
  @current_worker_utilization = 0
  for i in (0..1000)
    check_worker_utilization_load_shedder
  end

  # Workers are exhausted
  @current_worker_utilization = 1
  shed_count = 0
  for i in (0..NUMBER_OF_SECONDS_BEFORE_SHEDDING_STARTS + NUMBER_OF_SECONDS_TO_SHED_ALL_TRAFFIC)
    begin
      check_worker_utilization_load_shedder
    rescue RateLimitError
      shed_count += 1
    end
    sleep 1
  end
  puts "#{shed_count} requests were dropped" # Should be ~60

  # Should be shedding all traffic
  begin
    check_worker_utilization_load_shedder
    raise "it didn't work"
  rescue RateLimitError
    puts "it worked"
  end
end

test_check_worker_utilization_load_shedder

The concurrent_requests_limiter.lua file:

local key = KEYS[1]

local capacity = tonumber(ARGV[1])
local timestamp = tonumber(ARGV[2])
local id = ARGV[3]

local count = redis.call("zcard", key)
local allowed = count < capacity

if allowed then
  redis.call("zadd", key, timestamp, id)
end

return { allowed, count }

The request_rate_limiter.lua file:

local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
  new_tokens = filled_tokens - requested
end

redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)

return { allowed, new_tokens }

The shared file loaded by require_relative 'shared':

require 'redis'

class RateLimitError < RuntimeError; end

# In the real world this would be a class not a global variable
def redis
  $_redis ||= Redis.new
end

h-no commented Dec 1, 2021

@ptarjan is there a license on this?


flexoid commented Apr 8, 2022

@ptarjan another call for a license, please 🙂


syedfaisal3 commented Feb 22, 2023

@ptarjan @klboke @tonyjiang or anyone else - can you please explain how the ttl formula has been arrived at?

local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

@minyakonga

> @ptarjan @klboke @tonyjiang or anyone else - can you please explain how the ttl formula has been arrived at?
>
> local fill_time = capacity/rate
> local ttl = math.floor(fill_time*2)

I think this ttl is used to prevent traffic spikes.
Suppose there were no ttl (with rate 100 and capacity 400):
1. At time1, you request one token; allowed is true and the bucket now holds 399 tokens.
2. 32 seconds later, at time1 + 32, you request 399 tokens; allowed is true and the bucket now holds 1 token.

In the second step the actual QPS is 399 requests/second or higher, which is far greater than 100 requests/second.

@minyakonga

@syedfaisal3 damn, i think the ttl is useless, you can remove it.

@shyakadev

@ptarjan thank you for such amazing stuff


mehuled commented Dec 19, 2023

I could not understand why the ttl is being set to 2 times (capacity/replenish_rate).

tokens_key = "abc.key"
timestamp_key = "abc.timestamp"
rate = 100
capacity = 500
fill_time = 5
now = 2 -- in unix seconds
ttl = 10

-- After first request
-- In Redis
-- {abc.key} =  499 -- expires at 12
-- {abc.timestamp} =  2 -- expires at 12

-- We got 500 requests at timestamp = 2
-- From now till time = 12 in unix seconds
-- In Redis
abc.key =  0 -- expires 12
abc.timestamp =  2 -- expires 12

-- Means for next 10 seconds we cannot have any request which effectively means RPS is 500/10 = 50 RPS


runner112-113 commented Mar 18, 2024

I could not understand math.floor(fill_time*2):
if I consume all the tokens in the first fill_time, the second fill_time will never be able to get any tokens.
If we set the rate limiter per second, I instead record how many tokens have been consumed in the current second and how many tokens have been produced by the time of the request; comparing these two values decides whether to allow the request. The code follows:

redis.replicate_commands()

local used_tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = redis.call('TIME')[1]
local requested = tonumber(ARGV[4])

local fill_time = capacity/rate
--local ttl = math.floor(fill_time*2)

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. now)
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local current_second_used_tokens = tonumber(redis.call("get", used_tokens_key))
if current_second_used_tokens == nil then
  current_second_used_tokens = 0
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
if now-last_refreshed > 1000 then
  last_refreshed = math.floor((now - last_refreshed) / 1000) * 1000
end
local delta = math.max(0, now-last_refreshed)
local valid_tokens = math.min(capacity, (delta*rate) - current_second_used_tokens)
local allowed = valid_tokens >= requested

local allowed_num = 0
if allowed then
  current_second_used_tokens = current_second_used_tokens + requested
  allowed_num = 1
end

local ttl = 1000 - (now - last_refreshed)
--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)

if ttl > 0 then
  redis.call("setex", used_tokens_key, ttl, current_second_used_tokens)
  redis.call("setex", timestamp_key, ttl, last_refreshed)
end

-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num }


lalit10368 commented Jul 26, 2024

@ptarjan Why is the operation that removes outdated timestamps, i.e. redis.zremrangebyscore(key, '-inf', @timestamp - TTL), not included in the concurrent_requests_limiter.lua script itself?

@Cornelius-Adeyemi

Thanks for the well-written article. You made it very easy to understand.
