Created
March 5, 2015 23:20
-
-
Save sunilmallya/6cb40dc4b9761a800057 to your computer and use it in GitHub Desktop.
Wrapper class that supports both synchronous and asynchronous redis client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
A DB connection class offering both synchronous and asynchronous redis
clients. The asynchronous client runs redis-py calls on a thread pool so
that it works seamlessly with tornado. A retry wrapper is also built in,
to retry commands in case of connection failures to the redis server.
Requires Tornado 4.0; for the remaining requirements, see the imports.
'''
import concurrent.futures | |
import logging | |
import multiprocessing | |
import os | |
import redis as blockingRedis | |
import time | |
import tornado.ioloop | |
import tornado.gen | |
import tornado.httpclient | |
import tornado.web | |
from tornado.httpclient import HTTPResponse, HTTPRequest | |
from tornado.options import define, options | |
import threading | |
# testing | |
from mock import patch, MagicMock | |
import tornado.testing | |
import unittest | |
_log = logging.getLogger(__name__)

# Connection settings for the redis server used by DBConnection.
define("redisDB", default="127.0.0.1", help="Host of the main redis DB")
define("dbPort", default=6379, help="Port of the main redis DB")
# Retry policy shared by the blocking and async redis wrappers.
define("maxRedisRetries", default=3,
       help="Attempts made per redis command before giving up")
define("baseRedisRetryWait", default=5,
       help="Base backoff in seconds; after the n-th failure the wrapper "
            "waits 2**n times this value")
class DBStateError(ValueError):
    '''ValueError subclass reserved for database-state errors.'''
class DBConnection(object):
    '''Connection to the database.

    There is one cached connection per object type (keyed by class name),
    so to get a connection use the get() function instead of creating it
    directly:

        db_conn = DBConnection.get(self)   # from an instance method
        db_conn = DBConnection.get(cls)    # from a classmethod
    '''
    # NOTE: one lock guards creation of every cached instance, rather
    # than one lock per class name.
    __singleton_lock = threading.Lock()
    # class name -> cached connection object
    _singleton_instance = {}

    def __init__(self, class_name):
        '''Init function.

        DO NOT CALL THIS DIRECTLY. Use the get() function instead.

        class_name - the object-type name this connection is created for;
        it is not used yet because all types share one redis server.
        '''
        host = options.redisDB
        port = options.dbPort
        # NOTE: You can add conditionals here to connect to different
        # redis servers if they are sharded on class names
        self.conn, self.blocking_conn = RedisClient.get_client(host, port)

    def fetch_keys_from_db(self, key_prefix, callback=None):
        '''Fetch the keys matching `key_prefix` via the redis KEYS command.

        NOTE(review): the argument is passed to KEYS verbatim, so it acts as
        a glob pattern, not a prefix; callers presumably append '*'
        themselves -- confirm.

        With a callback the lookup is asynchronous, the callback receives
        the key list and None is returned; otherwise the list is returned.
        '''
        if callback:
            self.conn.keys(key_prefix, callback)
            return None
        return self.blocking_conn.keys(key_prefix)

    def clear_db(self):
        '''Erases all the keys in the database (FLUSHDB).

        This should really only be used in test scenarios.
        '''
        self.blocking_conn.flushdb()

    @classmethod
    def update_instance(cls, cname):
        '''Recreate the cached connection for `cname`, e.g. after a db
        config update. Does nothing if no connection is cached for it.'''
        if cname in cls._singleton_instance:
            with cls.__singleton_lock:
                if cname in cls._singleton_instance:
                    cls._singleton_instance[cname] = cls(cname)

    @classmethod
    def get(cls, otype=None):
        '''Gets (creating on first use) the DB connection for an object type.

        otype - The object type to get the connection for. Can be a class
                object, an instance object or the class name as a string.
                None maps to a single shared entry keyed by None.
        '''
        cname = None
        # 'is not None' so that falsy instances (e.g. __len__ == 0) are
        # still keyed by their class name.
        if otype is not None:
            # (str, unicode) on Python 2, (str, str) on Python 3.
            if isinstance(otype, (str, type(u''))):
                cname = otype
            elif isinstance(otype, type):
                # A class object (classmethod callers); isinstance also
                # covers classes built by a custom metaclass like ABCMeta.
                cname = otype.__name__
            else:
                cname = otype.__class__.__name__
        if cname not in cls._singleton_instance:
            with cls.__singleton_lock:
                # Re-check under the lock: another thread may have won.
                if cname not in cls._singleton_instance:
                    cls._singleton_instance[cname] = cls(cname)
        return cls._singleton_instance[cname]

    @classmethod
    def clear_singleton_instance(cls):
        '''
        Clear the cached connection for each of the classes.

        NOTE: To be only used by the test code
        '''
        with cls.__singleton_lock:
            cls._singleton_instance = {}
class RedisRetryWrapper(object):
    '''Wraps a redis client so that it retries with exponential backoff.

    You use this class exactly the same way that you would use the
    StrictRedis class: method calls are forwarded to an internal
    StrictRedis client and retried when they fail with one of
    `retry_exceptions` (connection/timeout errors by default). Any other
    error is raised immediately, because repeating the call cannot fix it.
    Non-callable client attributes are returned as-is.

    Calls on this object are blocking.
    '''
    def __init__(self, *args, **kwargs):
        self.client = blockingRedis.StrictRedis(*args, **kwargs)
        self.max_tries = options.maxRedisRetries
        self.base_wait = options.baseRedisRetryWait
        # Only transient transport failures are worth retrying; errors such
        # as a WRONGTYPE reply fail identically on every attempt.
        self.retry_exceptions = (blockingRedis.ConnectionError,
                                 blockingRedis.TimeoutError)

    def _get_wrapped_retry_func(self, func):
        '''Returns a blocking retry function wrapped around the given func.

        After the n-th failed attempt (n starting at 1) it sleeps
        2**n * base_wait seconds. Once max_tries attempts have failed, the
        last exception is re-raised. At least one attempt is always made.
        '''
        def RetryWrapper(*args, **kwargs):
            cur_try = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except self.retry_exceptions as e:
                    _log.error('Error talking to redis on attempt %i: %s',
                               cur_try, e)
                    cur_try += 1
                    # '>=' so that a non-positive max_tries cannot loop
                    # forever.
                    if cur_try >= self.max_tries:
                        raise
                    # Do an exponential backoff
                    delay = (1 << cur_try) * self.base_wait  # in seconds
                    time.sleep(delay)
        return RetryWrapper

    def __getattr__(self, attr):
        '''Allows us to wrap all of the redis-py functions.

        Callables of the underlying client come back wrapped with retry
        logic; plain attributes are returned unchanged. Raises
        AttributeError if the client has no such attribute.
        '''
        if attr == 'client':
            # client is not set yet: avoid infinite __getattr__ recursion.
            raise AttributeError(attr)
        target = getattr(self.client, attr)
        if callable(target):
            return self._get_wrapped_retry_func(target)
        return target
class RedisAsyncWrapper(object):
    '''
    Replacement class for tornado-redis.

    This is a wrapper class which runs each redis operation on a
    per-process background thread pool and, on completion, transfers
    control back to the tornado IOLoop that was current when the call was
    made. If you wrap this around gen.Task, you can write db operations as
    if they were synchronous.

    Calls that fail with one of `retry_exceptions` (connection/timeout
    errors by default) are resubmitted with exponential backoff scheduled
    on the IOLoop; any other error is raised at once. Non-callable client
    attributes are returned as-is.

    usage:
        value = yield tornado.gen.Task(RedisAsyncWrapper().get, key)

    #TODO: see if we can completely wrap redis-py calls, helpful if
    you can get the callback attribute as well when call is made
    '''
    # pid -> ThreadPoolExecutor (see _get_thread_pool)
    _thread_pools = {}
    _pool_lock = multiprocessing.RLock()
    _async_pool_size = 10

    def __init__(self, host='127.0.0.1', port=6379):
        self.client = blockingRedis.StrictRedis(host, port, socket_timeout=10)
        self.max_tries = options.maxRedisRetries
        self.base_wait = options.baseRedisRetryWait
        # Only transient transport failures are worth retrying; errors such
        # as a WRONGTYPE reply fail identically on every attempt.
        self.retry_exceptions = (blockingRedis.ConnectionError,
                                 blockingRedis.TimeoutError)

    @classmethod
    def _get_thread_pool(cls):
        '''Get the thread pool for this process, creating it on first use.

        Pools are keyed by pid, presumably so that a forked child builds
        its own pool instead of using the parent's (whose worker threads
        are not copied by fork).
        '''
        pid = os.getpid()
        with cls._pool_lock:
            pool = cls._thread_pools.get(pid)
            if pool is None:
                pool = concurrent.futures.ThreadPoolExecutor(
                    cls._async_pool_size)
                cls._thread_pools[pid] = pool
            return pool

    def _get_wrapped_async_func(self, func):
        '''Returns an asynchronous function wrapped around the given func.

        The returned function takes func's arguments plus a callback, given
        either as the `callback` keyword or as the last positional argument.
        It returns None immediately; the callback is later invoked on the
        IOLoop with func's result.

        Raises AttributeError if no callback is supplied.
        '''
        def AsyncWrapper(*args, **kwargs):
            # Find the callback argument (an explicit callback=None does
            # not count as one).
            callback = kwargs.pop('callback', None)
            if callback is None:
                if args and callable(args[-1]):
                    callback = args[-1]
                    args = args[:-1]
                else:
                    raise AttributeError('A callback is necessary')
            io_loop = tornado.ioloop.IOLoop.current()

            def _submit(cur_try):
                # Run func on the pool and route its outcome to _cb.
                future = RedisAsyncWrapper._get_thread_pool().submit(
                    func, *args, **kwargs)
                io_loop.add_future(future, lambda f: _cb(f, cur_try))

            def _cb(future, cur_try):
                exc = future.exception()
                if exc is None:
                    callback(future.result())
                    return
                # NOTE(review): errors are raised from an IOLoop callback,
                # so the user callback is never invoked; presumably
                # tornado's stack context re-raises them in a gen.Task
                # coroutine -- confirm.
                if not isinstance(exc, self.retry_exceptions):
                    raise exc
                _log.error('Error talking to redis on attempt %i: %s',
                           cur_try, exc)
                cur_try += 1
                # '>=' so that a non-positive max_tries cannot retry forever.
                if cur_try >= self.max_tries:
                    raise exc
                delay = (1 << cur_try) * self.base_wait  # in seconds
                # call_later is relative to IOLoop.time(), unlike the former
                # add_timeout(time.time() + delay), which assumed wall-clock.
                io_loop.call_later(delay, _submit, cur_try)

            _submit(0)
        return AsyncWrapper

    def __getattr__(self, attr):
        '''Allows us to wrap all of the redis-py functions.

        Callables of the underlying client come back as async wrappers;
        plain attributes are returned unchanged. Raises AttributeError if
        the client has no such attribute.
        '''
        if attr == 'client':
            # client is not set yet: avoid infinite __getattr__ recursion.
            raise AttributeError(attr)
        target = getattr(self.client, attr)
        if callable(target):
            return self._get_wrapped_async_func(target)
        return target
class RedisClient(object):
    '''
    Static class for REDIS configuration.

    get_client() builds a fresh pair of clients (async and blocking) for a
    host/port. It records the pair on the class as client/blocking_client,
    with the legacy aliases c/bc.
    '''
    #static variables
    host = '127.0.0.1'
    port = 6379
    client = None
    blocking_client = None

    def __init__(self, host='127.0.0.1', port=6379):
        self.client = RedisAsyncWrapper(host, port)
        # Same 10s socket timeout that get_client() and the async wrapper
        # use, so a hung server cannot block this client forever.
        self.blocking_client = RedisRetryWrapper(host, port, socket_timeout=10)

    @staticmethod
    def get_client(host=None, port=None):
        '''
        Return new connection objects as (non-blocking, blocking).

        host/port default to RedisClient.host / RedisClient.port.
        '''
        if host is None:
            host = RedisClient.host
        if port is None:
            port = RedisClient.port
        async_client = RedisAsyncWrapper(host, port)
        blocking_client = RedisRetryWrapper(host, port, socket_timeout=10)
        # Publish the clients under the declared names; keep c/bc for
        # callers that read the old attributes.
        RedisClient.client = RedisClient.c = async_client
        RedisClient.blocking_client = RedisClient.bc = blocking_client
        return async_client, blocking_client
class DBObject(object):
    '''
    Abstract base for an object stored in the database under a string key.

    Subclasses implement get_value() to supply the value that save()
    writes. You can make this more abstract by adding your own serializers
    and deserializers, so that python objects are stored in the DB and
    rebuilt on retrieval.

    # NOTE: Only get & save methods are currently implemented, but you get
    the basic idea, right...?
    '''
    def __init__(self, key):
        self.key = str(key)

    def save(self, callback=None):
        '''
        Save the object to the database.

        With a callback, the write is asynchronous, the callback receives
        the redis reply, and None is returned. Without one, the blocking
        client's reply is returned.
        '''
        connection = DBConnection.get(self)
        payload = self.get_value()
        if not callback:
            return connection.blocking_conn.set(self.key, payload)
        connection.conn.set(self.key, payload, callback)

    @classmethod
    def get(cls, key, callback=None):
        '''Read `key`: asynchronously into `callback` if given, otherwise
        return the value from the blocking client.'''
        connection = DBConnection.get(cls)
        if not callback:
            return connection.blocking_conn.get(key)
        connection.conn.get(key, callback)

    def get_value(self):
        '''Return the value to store; subclasses must override this.'''
        raise NotImplementedError()
################# TESTS ##########################
# Lil test class
class Foo(DBObject):
    '''
    Minimal concrete DBObject that stores a fixed value under a key.
    You get the idea now: the abstract class can be made more generic.
    '''
    def __init__(self, key, val):
        super(Foo, self).__init__(key)
        self.value = val

    def get_value(self):
        '''Return the value passed to the constructor.'''
        return self.value
################ Sanity tester ######################
def test_async():
    '''
    Async sanity test: fetch key 'test' through the async client, print
    it, then stop the IOLoop so that this function returns.

    If the callback never runs (e.g. the redis call fails), the loop keeps
    running and you'll need ^C or SIGTERM to exit.
    '''
    io_loop = tornado.ioloop.IOLoop.current()

    def callback(val):
        # A single-argument print() works the same on Python 2 and 3.
        print("async %s" % (val,))
        io_loop.stop()

    Foo.get('test', callback)
    io_loop.start()
################## Sample unit Test ####################
class TestHandler(tornado.web.RequestHandler):
    '''No-op handler that gives the test application a route.'''
    def initialize(self):
        '''Nothing to set up.'''
class TestAsyncDBConnection(tornado.testing.AsyncHTTPTestCase,
                            unittest.TestCase):
    '''
    Exercises the async redis path inside tornado's test IOLoop.

    more info: http://tornado.readthedocs.org/en/latest/testing.html
    '''
    # setUp is inherited unchanged from AsyncHTTPTestCase.

    def tearDown(self):
        # Drop cached connections so that later tests build fresh clients.
        DBConnection.clear_singleton_instance()
        super(TestAsyncDBConnection, self).tearDown()

    def get_app(self):
        return tornado.web.Application([(r'/', TestHandler)])

    @tornado.testing.gen_test
    def test_async_get(self):
        Foo('unittestfoo', 'bar').save()
        fetched = yield tornado.gen.Task(Foo.get, "unittestfoo")
        self.assertEqual(fetched, 'bar')

    # .... write more tests as you please
if __name__ == "__main__":
    # Blocking round-trip sanity check.
    # Assumes redis is running on localhost on 6379.
    sample = Foo('test', 'testval')
    sample.save()
    assert Foo.get('test') == 'testval'
    # Feel free to write more test code and play with this gist as
    # necessary; the main motive here is to provide a framework to work with.

    # sanity test (async)
    #test_async()

    # Unit test
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment