Skip to content

Instantly share code, notes, and snippets.

Created December 10, 2012 19:19
Show Gist options
  • Save anonymous/4252677 to your computer and use it in GitHub Desktop.
Save anonymous/4252677 to your computer and use it in GitHub Desktop.
Multi node redis library
'''
Created on Jun 25, 2012
@author: andres.rangel
'''
import logging
import os
import re
from redis.client import Redis, BasePipeline, ConnectionError, StrictRedis, Script, NoScriptError
from redis.connection import ConnectionPool, Connection, DefaultParser
from random import choice
from itertools import chain
# Module-level logger used by every class in this file.
logger = logging.getLogger('haredispy.redis_multiple')

# Command names that are treated as writes. The rows are grouped by data type
# for readability only; the result is one flat set used for membership tests.
_write_commands = set((
    # keys / server
    'DEL', 'FLUSHALL', 'FLUSHDB', 'SAVE',
    'EXPIRE', 'EXPIREAT', 'PERSIST', 'RENAME', 'RENAMENX', 'SORT',
    # strings
    'APPEND', 'DECRBY', 'INCRBY', 'MSET', 'MSETNX', 'GETSET',
    'SET', 'SETBIT', 'SETEX', 'SETNX', 'SETRANGE',
    # lists
    'BLPOP', 'BRPOP', 'BRPOPLPUSH', 'LINSERT', 'LPOP', 'LPUSH', 'LPUSHX',
    'LREM', 'LSET', 'LTRIM', 'RPOP', 'RPOPLPUSH', 'RPUSH', 'RPUSHX',
    # sets
    'SADD', 'SDIFFSTORE', 'SINTERSTORE', 'SMOVE', 'SPOP', 'SREM',
    'SUNIONSTORE',
    # sorted sets
    'ZADD', 'ZINCRBY', 'ZINTERSTORE', 'ZREM', 'ZREMRANGEBYSCORE',
    'ZREMRANGEBYRANK', 'ZUNIONSTORE',
    # hashes
    'HDEL', 'HINCRBY', 'HSET', 'HSETNX', 'HMSET',
    # scripting / transactions
    'EVAL', 'EVALSHA', 'SCRIPT', 'WATCH',
))
class WriteAwarePipeline(BasePipeline):
    """
    Pipeline that remembers whether any staged command is a write.

    ``execute`` hands that flag to the connection pool as the ``write``
    option, so the pool can choose which server runs the whole batch.
    """

    def __init__(self, connection_pool,
                 response_callbacks,
                 transaction,
                 shard_hint):
        super(WriteAwarePipeline, self).__init__(connection_pool,
                                                 response_callbacks,
                                                 transaction,
                                                 shard_hint)
        # Becomes True once a write command not flagged ``readonly`` is staged.
        self.write_flag = False

    def __enter__(self):
        return self

    def execute_command(self, *args, **kwargs):
        """
        Record whether the command is a write, then delegate to the base class.

        ``args[0]`` is the command name. A command listed in
        ``_write_commands`` sets the write flag unless the caller passed
        ``readonly=True`` (used for scripts that only read).
        """
        command = args[0]
        readonly = kwargs.get("readonly", False)
        if command and command in _write_commands and not readonly:
            self.write_flag = True
        return super(WriteAwarePipeline, self).execute_command(*args, **kwargs)

    def reset(self):
        """
        Reset the base pipeline, forget the connection and clear the write flag.

        Any exception raised while resetting is logged and swallowed; the
        write flag is cleared in every case.
        """
        try:
            super(WriteAwarePipeline, self).reset()
            self.connection = None
        except Exception as e:
            logger.error("Got exception %s while trying to reset connection. Ignoring ...", e)
        finally:
            self.write_flag = False

    def execute(self, raise_on_error=True):
        """
        Execute all the commands staged in the pipeline.

        The connection is requested here (rather than by the base class) so
        that the ``write`` option can be passed to the pool. On a
        ``ConnectionError`` the connection is killed and the batch is retried
        once on a freshly created connection.
        """
        options = {"write": self.write_flag}
        # NOTE(review): a connection already stored on self.connection is
        # overwritten here without being released -- confirm the base class
        # never leaves one there before execute() is called.
        self.connection = self.connection_pool.get_connection('MULTI', self.shard_hint, **options)
        try:
            return super(WriteAwarePipeline, self).execute(raise_on_error)
        except ConnectionError:
            self.connection_pool.kill_connection(self.connection)
            self.connection = None
            # NOTE(review): the retry is unconditional; the upstream check that
            # raised WatchError for watched keys was disabled in this class.
            self.connection = self.connection_pool.get_new_connection('MULTI', self.shard_hint, **options)
            return super(WriteAwarePipeline, self).execute(raise_on_error)

    def pipeline_execute_command(self, *args, **options):
        """
        Stage a command to be executed when execute() is next called.

        Returns the current pipeline object so commands can be chained:

            pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        A later ``pipe.execute()`` runs every command queued in the pipe.
        """
        self.command_stack.append((args, options))
        return self

    def evalsha(self, sha, numkeys, *keys_and_args, **options):
        """
        Use the ``sha`` to execute a LUA script already registered via EVAL
        or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the
        key names and argument values in ``keys_and_args``. Extra ``options``
        (for example ``readonly``) are forwarded to ``execute_command``.

        In practice, use the object returned by ``register_script``.
        """
        return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args, **options)

    def script_load(self, script, **options):
        "Load a LUA ``script`` into the script cache. Returns the SHA."
        options.update({'parse': 'LOAD'})
        return self.execute_command('SCRIPT', 'LOAD', script, **options)

    def script_load_all(self, script, **options):
        """
        Load a LUA ``script`` on every slave known to the pool, then load it
        through the parent class and return that result.

        Slaves are only contacted when the pool is a
        ``HighAvailabilityConnectionPool``. A ``ConnectionError`` on one slave
        is logged and does not stop the remaining slaves from being loaded.
        """
        pool = self.connection_pool
        if isinstance(pool, HighAvailabilityConnectionPool):
            options.update({'parse': 'LOAD'})
            args = ['SCRIPT', 'LOAD', script]
            for connection_key in pool.get_slaves():
                connection = pool.obtain_specified_connection_read(connection_key)
                try:
                    connection.send_command(*args)
                    rsp = self.parse_response(connection, 'SCRIPT', **options)
                    logger.debug("obtained: {0} ".format(rsp))
                except ConnectionError as c:
                    pool.kill_connection(connection)
                    # Already removed from the pool: nothing left to release.
                    connection = None
                    logger.error("Got exception:%s", c)
                finally:
                    # Release on success and on non-connection errors, so the
                    # connection is never stranded in the pool's in-use set.
                    if connection is not None:
                        pool.release(connection)
        return super(WriteAwarePipeline, self).script_load(script)
class WriteAwareStrictPipeline(WriteAwarePipeline, StrictRedis):
    """Write-aware pipeline exposing the StrictRedis command set."""
class WriteAwareRedisPipeline(WriteAwarePipeline, Redis):
    """Write-aware pipeline exposing the Redis command set."""
class HighAvailabilityStrictRedis(StrictRedis):
    """
    StrictRedis client that asks the connection pool for a connection per
    command, passing the command name and options so the pool can choose the
    server, and that retries once on a new connection after a ConnectionError.
    """

    def pipeline(self, transaction=False, shard_hint=None):
        """
        Returns the write aware pipeline
        """
        return WriteAwareStrictPipeline(
            self.connection_pool,
            self.response_callbacks,
            transaction,
            shard_hint)

    def __str__(self):
        return "%s" % self.connection_pool

    def slowlog_get(self, value=-1):
        "Return slowlog entries; at most ``value`` of them when ``value`` > 0."
        if value > 0:
            return self.execute_command('SLOWLOG', 'GET', value)
        return self.execute_command('SLOWLOG', 'GET')

    def slowlog_len(self):
        "Return the length of the slowlog."
        return self.execute_command('SLOWLOG', 'LEN')

    def slowlog_reset(self):
        "Reset the slowlog."
        return self.execute_command('SLOWLOG', 'RESET')

    def execute_command(self, *args, **options):
        """
        Execute a command and return a parsed response.

        ``args[0]`` is the command name; it and ``options`` are passed to the
        pool when requesting a connection. On a ``ConnectionError`` the
        connection is killed and the command is retried once on a new one.
        Whatever connection is still alive at the end is released to the pool.
        """
        pool = self.connection_pool
        command_name = args[0]
        connection = pool.get_connection(command_name, **options)
        try:
            connection.send_command(*args)
            return self.parse_response(connection, command_name, **options)
        except ConnectionError as c:
            logger.error("Got exception:%s", c)
            pool.kill_connection(connection)
            # The killed connection must not be released if the pool fails to
            # hand out a replacement.
            connection = None
            connection = pool.get_new_connection(command_name, **options)
            connection.send_command(*args)
            return self.parse_response(connection, command_name, **options)
        finally:
            if connection is not None:
                pool.release(connection)

    def register_script(self, script):
        """
        Register a LUA ``script`` specifying the ``keys`` it will touch.
        Returns a Script object that is callable and hides the complexity of
        deal with scripts, keys, and shas. This is the preferred way to work
        with LUA scripts.
        """
        return HaScript(self, script)

    def evalsha(self, sha, numkeys, *keys_and_args, **options):
        """
        Use the ``sha`` to execute a LUA script already registered via EVAL
        or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the
        key names and argument values in ``keys_and_args``. Extra ``options``
        (for example ``readonly``) are forwarded to ``execute_command``.

        In practice, use the object returned by ``register_script``.
        """
        return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args, **options)

    def script_load(self, script, **options):
        "Load a LUA ``script`` into the script cache. Returns the SHA."
        options.update({'parse': 'LOAD'})
        return self.execute_command('SCRIPT', 'LOAD', script, **options)

    def script_load_all(self, script, **options):
        """
        Load a LUA ``script`` on every slave known to the pool, then load it
        through the parent class and return that result.

        Slaves are only contacted when the pool is a
        ``HighAvailabilityConnectionPool``. A ``ConnectionError`` on one slave
        is logged and does not stop the remaining slaves from being loaded.
        """
        pool = self.connection_pool
        if isinstance(pool, HighAvailabilityConnectionPool):
            options.update({'parse': 'LOAD'})
            args = ['SCRIPT', 'LOAD', script]
            for connection_key in pool.get_slaves():
                connection = pool.obtain_specified_connection_read(connection_key)
                try:
                    connection.send_command(*args)
                    rsp = self.parse_response(connection, 'SCRIPT', **options)
                    logger.debug("obtained: {0} ".format(rsp))
                except ConnectionError as c:
                    pool.kill_connection(connection)
                    # Already removed from the pool: nothing left to release.
                    connection = None
                    logger.error("Got exception:%s", c)
                finally:
                    # Release on success and on non-connection errors, so the
                    # connection is never stranded in the pool's in-use set.
                    if connection is not None:
                        pool.release(connection)
        return super(HighAvailabilityStrictRedis, self).script_load(script)
class HighAvailabilityRedis(Redis):
    """
    Redis client that asks the connection pool for a connection per command,
    passing the command name and options so the pool can choose the server,
    and that retries once on a new connection after a ConnectionError.
    """

    def pipeline(self, transaction=False, shard_hint=None):
        """
        Returns the write aware pipeline
        """
        return WriteAwareRedisPipeline(
            self.connection_pool,
            self.response_callbacks,
            transaction,
            shard_hint)

    def __str__(self):
        return "%s" % self.connection_pool

    def slowlog_get(self, value=-1):
        "Return slowlog entries; at most ``value`` of them when ``value`` > 0."
        if value > 0:
            return self.execute_command('SLOWLOG', 'GET', value)
        return self.execute_command('SLOWLOG', 'GET')

    def slowlog_len(self):
        "Return the length of the slowlog."
        return self.execute_command('SLOWLOG', 'LEN')

    def slowlog_reset(self):
        "Reset the slowlog."
        return self.execute_command('SLOWLOG', 'RESET')

    def execute_command(self, *args, **options):
        """
        Execute a command and return a parsed response.

        ``args[0]`` is the command name; it and ``options`` are passed to the
        pool when requesting a connection. On a ``ConnectionError`` the
        connection is killed and the command is retried once on a new one.
        Whatever connection is still alive at the end is released to the pool.
        """
        pool = self.connection_pool
        command_name = args[0]
        connection = pool.get_connection(command_name, **options)
        try:
            connection.send_command(*args)
            return self.parse_response(connection, command_name, **options)
        except ConnectionError as c:
            logger.error("Got exception:%s", c)
            pool.kill_connection(connection)
            # The killed connection must not be released if the pool fails to
            # hand out a replacement.
            connection = None
            connection = pool.get_new_connection(command_name, **options)
            connection.send_command(*args)
            return self.parse_response(connection, command_name, **options)
        finally:
            if connection is not None:
                pool.release(connection)

    def register_script(self, script):
        """
        Register a LUA ``script`` specifying the ``keys`` it will touch.
        Returns a Script object that is callable and hides the complexity of
        deal with scripts, keys, and shas. This is the preferred way to work
        with LUA scripts.
        """
        return HaScript(self, script)

    def script_load(self, script, **options):
        "Load a LUA ``script`` into the script cache. Returns the SHA."
        options.update({'parse': 'LOAD'})
        return self.execute_command('SCRIPT', 'LOAD', script, **options)

    def script_load_all(self, script, **options):
        """
        Load a LUA ``script`` on every slave known to the pool, then load it
        through the parent class and return that result.

        Slaves are only contacted when the pool is a
        ``HighAvailabilityConnectionPool``. A ``ConnectionError`` on one slave
        is logged and does not stop the remaining slaves from being loaded.
        """
        pool = self.connection_pool
        if isinstance(pool, HighAvailabilityConnectionPool):
            options.update({'parse': 'LOAD'})
            args = ['SCRIPT', 'LOAD', script]
            for connection_key in pool.get_slaves():
                connection = pool.obtain_specified_connection_read(connection_key)
                try:
                    connection.send_command(*args)
                    rsp = self.parse_response(connection, 'SCRIPT', **options)
                    logger.debug("obtained: {0} ".format(rsp))
                except ConnectionError as c:
                    pool.kill_connection(connection)
                    # Already removed from the pool: nothing left to release.
                    connection = None
                    logger.error("Got exception:%s", c)
                finally:
                    # Release on success and on non-connection errors, so the
                    # connection is never stranded in the pool's in-use set.
                    if connection is not None:
                        pool.release(connection)
        return super(HighAvailabilityRedis, self).script_load(script)

    def evalsha(self, sha, numkeys, *keys_and_args, **options):
        """
        Use the ``sha`` to execute a LUA script already registered via EVAL
        or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the
        key names and argument values in ``keys_and_args``. Extra ``options``
        (for example ``readonly``) are forwarded to ``execute_command``.

        In practice, use the object returned by ``register_script``.
        """
        return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args, **options)
class HaConnection(Connection):
    """
    Manages TCP communication to and from a Redis server and carries a
    ``connection_type`` label (``'master'`` by default).
    """

    def __init__(self, host='localhost', port=6379, db=0, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 connection_type='master', parser_class=DefaultParser):
        # Both attributes are set before the parent constructor runs, exactly
        # as in the original ordering.
        self.connection_type = connection_type
        self.pid = os.getpid()
        parent_kwargs = dict(
            host=host,
            port=port,
            db=db,
            password=password,
            socket_timeout=socket_timeout,
            encoding=encoding,
            encoding_errors=encoding_errors,
            decode_responses=decode_responses,
            parser_class=parser_class,
        )
        super(HaConnection, self).__init__(**parent_kwargs)
# Captures the quoted command name passed as first argument to every
# redis.call(...) / redis.pcall(...) found in a line of LUA. The name is
# matched up to its closing quote only, so calls without further arguments
# (no comma) and calls with several quoted arguments are both recognised.
regex = re.compile(r"redis\.p?call\s*\(\s*['\"]([^'\"]+)['\"]")


class HaScript(Script):
    "An executable LUA script object returned by ``register_script``"

    def __init__(self, registered_client, script):
        """
        Clean the script, work out whether it is read-only, register it via
        the parent class and then load it through ``script_load_all``.
        """
        (self.readonly, script) = self._parse_script(script)
        super(HaScript, self).__init__(registered_client=registered_client, script=script)
        registered_client.script_load_all(script)

    def _parse_script(self, script):
        """
        Strip every line, drop empty lines and lines starting with ``#``, and
        check whether any remaining line issues a write command.

        Returns ``(readonly, cleaned_script)``. Each kept line is prefixed
        with a newline, so the cleaned text starts with ``"\\n"``.
        """
        readonly = True
        kept = []
        for raw_line in script.splitlines():
            text = raw_line.strip()
            if not text or text[0] == '#':
                continue
            # Stop scanning for writes once one has been found.
            if readonly and self.check_for_write(text):
                readonly = False
            kept.append(text)
        return (readonly, "".join("\n" + text for text in kept))

    def check_for_write(self, line):
        """
        Return True when any ``redis.call``/``redis.pcall`` on ``line`` names
        a command listed in ``_write_commands`` (case-insensitive).

        Only literal, quoted command names are recognised; a command passed
        through a LUA variable is not detected.
        """
        for command in regex.findall(line):
            if command.strip().upper() in _write_commands:
                return True
        return False

    def __call__(self, keys=(), args=(), client=None):
        "Execute the script, passing any required ``args``"
        client = client or self.registered_client
        args = tuple(keys) + tuple(args)
        options = {'readonly': self.readonly}
        # make sure the Redis server knows about the script
        if isinstance(client, BasePipeline):
            # make sure this script is good to go on pipeline
            client.script_load_for_pipeline(self)
        try:
            return client.evalsha(self.sha, len(keys), *args, **options)
        except NoScriptError:
            # Maybe the client is accessing the slave, try master: the script
            # is reloaded and the retry is sent without the readonly option.
            self.sha = client.script_load(self.script)
            return client.evalsha(self.sha, len(keys), *args)
class HighAvailabilityConnectionPool(ConnectionPool):
    """
    Connection pool that knows one master and any number of slaves.

    ``slaves`` is a collection of dictionaries of connection arguments. The
    minimum requirement for a slave is ``host``; every other argument
    (port, db, password, ...) defaults to the value used for the master.
    Example:

        slaves = [{'host': 'localhost', 'port': 6380},
                  {'host': '10.16.2.1'},
                  {'host': 'localhost', 'port': 6381}]

    In this case there are 3 slaves, two running on localhost on different
    ports, and one running on another box. Unix sockets are not supported.
    """

    # Port assumed when none is configured (same default as HaConnection).
    _default_port = 6379

    def __init__(self, connection_class=Connection, max_connections=None, slaves=None, **connection_kwargs):
        """
        ``connection_class`` is accepted for signature compatibility only;
        ``HaConnection`` is always used. ``connection_kwargs`` describe the
        master and must contain ``host`` (``KeyError`` otherwise).

        Raises ``Exception`` when ``unix_socket_path`` is passed or when a
        slave definition has no ``host``.
        """
        self.connection_class = HaConnection
        self.max_connections = max_connections or 2**31
        self.pid = os.getpid()
        self._master_connections = []
        self._slave_connections = []
        self.master_kwargs = dict(connection_kwargs)
        if 'unix_socket_path' in connection_kwargs:
            # This library does not support unix sockets, only tcp sockets
            logger.error("Please provide host. this library does not support unix sockets.\nDo not pass unix_socket_path in the arguments!!")
            raise Exception('Please provide host. this library does not support unix sockets')
        # Pool keys are "host:port", so a port must always be known.
        self.master_kwargs.setdefault('port', self._default_port)
        # Arguments every slave inherits from the master (everything but host).
        self._base_dict = dict(self.master_kwargs)
        self._base_dict.pop('host', None)
        self.master_kwargs.update({
            'host': connection_kwargs['host'],
        })
        self._available_slave_connections = {}
        self._inuse_slave_connections = {}
        self.slaves_kwargs = {}
        self._created_slaves_connections = {}
        self._removed_slaves = set()
        self._removed_master = set()
        if slaves:
            for s in slaves:
                self.add_slave(s)
        # If there aren't any slaves, every command goes to the master.
        self._have_slaves = len(self.slaves_kwargs) > 0
        self._available_write_connections = []
        self._in_use_write_connections = set()
        self._created_write_connections = 0
        super(HighAvailabilityConnectionPool, self).__init__(HaConnection, self.max_connections, **self.master_kwargs)

    def _checkpid(self):
        """
        Rebuild the pool when running in another process than the one that
        created it, keeping the master and slave configuration.
        """
        if self.pid != os.getpid():
            self.disconnect()
            # Snapshot the slaves first: __init__ empties slaves_kwargs.
            slaves = list(self.slaves_kwargs.values())
            self.__init__(self.connection_class, self.max_connections,
                          slaves=slaves, **self.master_kwargs)

    def _get_key(self, host, port):
        "Build the ``host:port`` key identifying a server inside the pool."
        return "%s:%s" % (host.lower(), str(port))

    def add_slave(self, args):
        """
        Adds a slave into the list of slaves available.
        If it exists already, an error is logged and nothing is changed.

        Raises ``Exception`` when ``args`` has no ``host``.
        """
        if 'host' not in args:
            logger.error("Slave not valid.\nPlease provide the host for the slave. This library does not support unix sockets.")
            raise Exception('Slave not valid. Missing host in the configuration')
        _tmp = dict(self._base_dict)
        _tmp.update(args)
        _tmp.setdefault('port', self._default_port)
        key = self._get_key(_tmp['host'], _tmp['port'])
        if key in self.slaves_kwargs:
            logger.error("slave exists already, will not be added. Need to remove first!!")
            return
        self.slaves_kwargs[key] = _tmp
        self._available_slave_connections[key] = []
        self._inuse_slave_connections[key] = set()
        self._created_slaves_connections[key] = 0
        # Slaves added after construction must be used for reads as well.
        self._have_slaves = True
        if key in self._removed_slaves:
            logger.info("Adding a slave that was previously removed.")
            self._removed_slaves.remove(key)

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection from the pool.

        The master is used when there are no slaves, or when the request is
        not flagged ``readonly`` and is a write (``command_name`` is in
        ``_write_commands`` or the ``write`` option is set, as a write-aware
        pipeline does). Everything else goes to a randomly chosen slave.
        """
        self._checkpid()
        is_write = command_name in _write_commands or options.get('write', False)
        if not self._have_slaves or (is_write and not options.get('readonly', False)):
            connection = self._obtain_connection_write()
            connection.connection_type = 'master'
        else:
            connection = self._obtain_connection_read()
            connection.connection_type = 'slave'
        return connection

    def get_new_connection(self, command_name, *keys, **options):
        """
        Create a brand new connection instead of reusing an available one.

        Routing is the same as ``get_connection`` except that the
        ``readonly`` option is not consulted: any write command goes to the
        master.
        """
        self._checkpid()
        if not self._have_slaves or (command_name in _write_commands) or options.get('write', False):
            connection = self._make_connection(self.master_kwargs)
            self._in_use_write_connections.add(connection)
            connection.connection_type = 'master'
        else:
            # list() because dict views cannot be indexed by choice().
            key = choice(list(self.slaves_kwargs))
            connection = self._make_slave_connection(key)
            self._inuse_slave_connections[key].add(connection)
            connection.connection_type = 'slave'
        return connection

    def _obtain_connection_write(self):
        """
        obtains a connection from the master, reusing an available one if any
        """
        try:
            connection = self._available_write_connections.pop()
        except IndexError:
            connection = self._make_connection(self.master_kwargs)
        connection.connection_type = 'master'
        self._in_use_write_connections.add(connection)
        return connection

    def _obtain_connection_read(self):
        """
        obtains a connection from one of the slaves, chosen at random
        """
        key = choice(list(self.slaves_kwargs))
        return self.obtain_specified_connection_read(key)

    def obtain_specified_connection_read(self, key):
        """
        obtains a connection from the slave identified by ``key``
        (``host:port``), reusing an available one if any
        """
        try:
            connection = self._available_slave_connections[key].pop()
        except IndexError:
            connection = self._make_slave_connection(key)
        connection.connection_type = 'slave'
        self._inuse_slave_connections[key].add(connection)
        return connection

    def make_connection(self):
        """
        This should not be called
        """
        print("Error this shouldn't be called")
        return super(HighAvailabilityConnectionPool, self).make_connection()

    def _make_connection(self, kwargs):
        "Create a new master connection; ConnectionError when the limit is reached"
        if self._created_write_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_write_connections += 1
        return self.connection_class(**kwargs)

    def _make_slave_connection(self, key):
        "Create a new connection to slave ``key``; ConnectionError when the limit is reached"
        if self._created_slaves_connections[key] >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_slaves_connections[key] += 1
        return self.connection_class(**self.slaves_kwargs[key])

    def release(self, connection):
        """
        Releases the connection back to the pool
        No need to disconnect!!!

        A connection the pool is not tracking as in use is only logged.
        """
        self._checkpid()
        if connection.pid != self.pid:
            return
        key = self._get_key(connection.host, connection.port)
        if connection.connection_type == 'slave':
            try:
                self._inuse_slave_connections[key].remove(connection)
                self._available_slave_connections[key].append(connection)
            except KeyError:
                logger.error("Did not find connection with key[%s] in slaves. connection type:%s" % (key, connection.connection_type))
        else:
            try:
                self._in_use_write_connections.remove(connection)
                self._available_write_connections.append(connection)
            except KeyError:
                logger.error("Did not find connection with key[%s] in master. Connection type:%s" % (key, connection.connection_type))

    def disconnect(self):
        """
        Disconnects all connections in the pool
        Make sure to disconnect all read and write connections
        """
        # disconnect the master conns
        self._disconnect(self._available_write_connections, self._in_use_write_connections)
        # Disconnect the slaves connections
        for k in list(self.slaves_kwargs):
            self._disconnect(self._available_slave_connections.get(k, []), self._inuse_slave_connections.get(k, []))

    def _disconnect(self, available, in_use):
        "Disconnect every connection found in both collections"
        for connection in chain(in_use, available):
            connection.disconnect()

    def __str__(self):
        t = "Master [#Available, #In Use]\n\t%s, [%d,%d]\nSlaves [#Available, #In Use]:" % (
            self._get_key(self.master_kwargs['host'], self.master_kwargs.get('port', self._default_port)),
            len(self._available_write_connections), len(self._in_use_write_connections))
        for k in self.slaves_kwargs:
            t += "\n\t%s, [%d,%d]" % (k, len(self._available_slave_connections[k]), len(self._inuse_slave_connections[k]))
        return t

    def get_slaves(self):
        """
        returns a copy of dictionary of slaves arguments, where the key is
        host:port
        """
        return self.slaves_kwargs.copy()

    def get_master(self):
        """
        returns a copy of the parameters of the master. It is a dictionary
        """
        return self.master_kwargs.copy()

    def kill_connection(self, connection):
        """
        Remove ``connection`` from the pool for good and disconnect it.

        The created-connection counter is decremented when the connection was
        tracked as in use, so killed connections do not count against
        ``max_connections``. Errors are logged, never raised.
        """
        try:
            self._checkpid()
            if connection.pid == self.pid:
                key = self._get_key(connection.host, connection.port)
                if connection.connection_type == 'slave':
                    in_use = self._inuse_slave_connections.get(key)
                    if in_use is not None and connection in in_use:
                        in_use.remove(connection)
                        self._created_slaves_connections[key] -= 1
                elif connection in self._in_use_write_connections:
                    self._in_use_write_connections.remove(connection)
                    self._created_write_connections -= 1
            # Always close the socket, even for an untracked connection.
            connection.disconnect()
        except Exception as e:
            logger.error("could not disconnect. error: %s:" % e)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment