Created October 6, 2014 09:29
-
-
Save lgelo/164c97bd3b6d9f2d1893 to your computer and use it in GitHub Desktop.
Simple rate limiter based on token bucket algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from time import time | |
from threading import Lock | |
class RateLimiter(object):
    """
    Per-key rate limiter based on the token bucket algorithm.

    Every key gets its own bucket holding at most ``rate`` tokens. A bucket
    refills continuously at ``rate`` tokens per second, and each successful
    grant() removes the requested number of tokens. Every public method runs
    under ``self.lock``.

    Original implementation:
    http://code.activestate.com/recipes/578659-python-3-token-bucket-rate-limit/
    """

    def __init__(self, rate=0, ttl=0):
        """
        Create a limiter.

        :param rate: maximum rate per second, which is also the bucket
            capacity. 0 (the default) makes grant() refuse everything until
            set_rate() is called.
        :param ttl: seconds a bucket may stay unused before expire() deletes
            it. 0 (the default) disables expiry.
        """
        # key -> {'tokens': current token count, 'last': time() of last refill}
        self.buckets = dict()
        self.rate = rate
        self.ttl = ttl
        self.lock = Lock()

    def set_rate(self, rate):
        """
        Set maximum rate per second.

        This is also the bucket capacity, i.e. the largest number of tokens
        a single key can be granted in one burst.
        """
        with self.lock:
            self.rate = rate

    def set_ttl(self, ttl):
        """
        Set TTL for buckets.

        Buckets that have been unused for longer than TTL seconds are deleted
        when self.expire() is called.
        """
        with self.lock:
            self.ttl = ttl

    def grant(self, key, tokens=1):
        """
        Try to take ``tokens`` tokens from the bucket identified by ``key``.

        A bucket is created full on first use. Returns True if the bucket
        held enough tokens (they are then consumed), otherwise False. Always
        returns False while no rate is set. A request for more than ``rate``
        tokens can never succeed, because ``rate`` is the bucket capacity.
        """
        with self.lock:
            if not self.rate:
                return False
            now = time()
            bucket = self.buckets.get(key)
            if bucket is None:
                bucket = {'tokens': self.rate, 'last': now}
                self.buckets[key] = bucket
            # time() is wall-clock time and can step backwards. Clamp so that
            # a negative lapse never drains tokens from the bucket.
            lapse = max(0, now - bucket['last'])
            bucket['last'] = now
            # Refill proportionally to elapsed time, capped at capacity.
            bucket['tokens'] = min(self.rate,
                                   bucket['tokens'] + lapse * self.rate)
            if bucket['tokens'] < tokens:
                return False
            bucket['tokens'] -= tokens
            return True

    def expire(self):
        """
        Delete buckets that have not been touched for longer than the TTL.

        A bucket counts as touched whenever grant() refills it, whether or
        not that call succeeds. Returns the number of deleted buckets
        (0 when no TTL is set).
        """
        with self.lock:
            if not self.ttl:
                return 0
            cutoff = time() - self.ttl
            expired = [k for k, v in self.buckets.items()
                       if v['last'] < cutoff]
            # Explicit loop: map() is lazy on Python 3 and would delete nothing.
            for k in expired:
                del self.buckets[k]
            return len(expired)
if __name__ == '__main__':
    # Demo: allow 5 grants per second for 'client-1', sleep whenever its
    # bucket runs dry, and expire buckets unused for more than 5 seconds.
    # sys.stdout.write is used instead of print so the demo runs unchanged
    # on both Python 2 and Python 3.
    import sys
    from time import sleep

    rl = RateLimiter()
    rl.set_rate(5)
    rl.set_ttl(5)
    # Touched once and never again, so expire() should report it later on.
    rl.grant('one-shot-client')
    for _ in range(100):
        if not rl.grant('client-1'):
            sys.stdout.write("sleep\n")
            sleep(1)
        else:
            sys.stdout.write(". ")
        # Dots carry no newline; flush so progress is visible as it happens.
        sys.stdout.flush()
        expired = rl.expire()
        if expired:
            sys.stdout.write("expired buckets: %d\n" % expired)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice Work!