Skip to content

Instantly share code, notes, and snippets.

@qix
Created August 19, 2016 18:12
Show Gist options
  • Save qix/e51f9583d126102124854a0384ab218b to your computer and use it in GitHub Desktop.
Save qix/e51f9583d126102124854a0384ab218b to your computer and use it in GitHub Desktop.
import asyncio
import inspect
import logging
import sys
from collections import deque

import hiredis
# Module-level logger. The dotted name is fixed (not derived from __name__),
# presumably so records land under the 'pylib.aio' hierarchy however this
# file is imported or run.
logger = logging.getLogger(name='pylib.aio.redis_server')
class RedisProtocol(asyncio.Protocol):
    """asyncio protocol speaking the Redis wire format (RESP) as a server.

    Incoming bytes are fed to a ``hiredis.Reader``. Every complete message is
    dispatched to :meth:`command_received` as its own task. Replies are
    written strictly in request order: a task that finishes early stays in
    ``self.tasks`` until every earlier task has finished as well.

    Subclasses must override :meth:`command_received`.
    """

    def __init__(self, loop):
        super().__init__()
        self.reader = hiredis.Reader()
        self.loop = loop
        # Pending command tasks, oldest first; replies are flushed from the left.
        self.tasks = deque()
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def connection_lost(self, exc):
        """Cancel commands still in flight; their replies can no longer be sent."""
        # Emptying the deque first means the cancelled tasks' done-callbacks
        # (task_done) find nothing to flush.
        pending = list(self.tasks)
        self.tasks.clear()
        for task in pending:
            task.cancel()

    async def command_received(self, command):
        """Handle one parsed command and return the value to send back."""
        raise NotImplementedError()

    def task_done(self, task):
        """Flush replies for every finished task at the head of the queue."""
        while self.tasks and self.tasks[0].done():
            self.write_task(self.tasks.popleft())

    def data_received(self, data):
        self.reader.feed(data)
        # hiredis returns exactly False when no complete message is buffered.
        # A parsed command can itself be falsy (an empty array, or None for a
        # nil), so test for False explicitly; otherwise parsing would stop early
        # and later pipelined commands would sit in the buffer.
        command = self.reader.gets()
        while command is not False:
            task = asyncio.ensure_future(self.command_received(command), loop=self.loop)
            self.tasks.append(task)
            task.add_done_callback(self.task_done)
            command = self.reader.gets()

    def write_task(self, task):
        """Write the reply for a finished *task*, or an error reply if it failed."""
        if self.transport is None or self.transport.is_closing():
            # The peer is gone; there is nowhere to send the reply.
            return
        try:
            self.write_value(task.result())
        except (Exception, asyncio.CancelledError):
            # CancelledError is listed explicitly because it is not an
            # Exception subclass on Python 3.8+. Every failed command must still
            # produce exactly one reply to keep the pipelined stream in order.
            logger.exception('Redis command threw an exception')
            self.write_error('Internal server error')

    def write_value(self, value):
        """Encode a handler's return value as a RESP reply.

        ``None`` -> nil bulk string, ``bytes``/``str`` -> bulk string (str is
        UTF-8 encoded), ``int`` -> integer. Exact type checks are used, so
        ``bool`` and subclasses are rejected.

        Raises:
            Exception: if the type is not supported. Nothing is written first.
        """
        if value is None:
            self.write_null()
        elif type(value) is bytes:
            self.write_bytes(value)
        elif type(value) is str:
            self.write_bytes(value.encode('utf-8'))
        elif type(value) is int:
            self.write_integer(value)
        else:
            raise Exception('redis return type not handled')

    def write_simple_string(self, message, symbol=b'+'):
        """Write a single-line RESP reply: *symbol*, *message*, CRLF.

        Raises:
            ValueError: if *message* contains CR or LF. Such a message would
                terminate the line early and corrupt the reply stream.
        """
        encoded = message.encode('utf-8')
        if b'\r' in encoded or b'\n' in encoded:
            raise ValueError('simple strings cannot contain CR or LF')
        self.transport.write(symbol + encoded + b'\r\n')

    def write_integer(self, value):
        """Write a RESP integer (``:<n>``); *value* must fit in a signed 64-bit int.

        Raises:
            ValueError: if *value* is out of the signed 64-bit range. An
                explicit check is used because ``assert`` disappears under ``-O``.
        """
        if not -(2 ** 63) <= value < 2 ** 63:
            raise ValueError('value was not in integer range')
        return self.write_simple_string(str(value), symbol=b':')

    def write_error(self, message):
        """Write a RESP error reply (``-<message>``)."""
        return self.write_simple_string(message, symbol=b'-')

    def write_null(self):
        """Write the RESP nil bulk string (``$-1``)."""
        return self.write_simple_string('-1', symbol=b'$')

    def write_bytes(self, value):
        """Write *value* as a length-prefixed RESP bulk string."""
        length = str(len(value)).encode('ascii')
        self.transport.write(b''.join((b'$', length, b'\r\n', value, b'\r\n')))
class RedisCommandHandler(RedisProtocol):
    """Dispatch each parsed command to a handler looked up by name.

    ``commands`` maps lower-case command names to callables. The first
    element of a command is its name (bytes from the parser, decoded as UTF-8
    and lower-cased). The remaining elements are passed to the handler
    positionally. A handler may return a plain value or an awaitable (e.g. it
    is an ``async def``); awaitables are awaited before the reply is written.
    """

    def __init__(self, loop, commands):
        super().__init__(loop)
        self.commands = commands

    async def command_received(self, command):
        if not command:
            raise ValueError('received an empty command')
        name = command[0].decode('utf-8').lower()
        try:
            handler = self.commands[name]
        except KeyError:
            # Same exception type as before, but with a message that names the
            # offending command in the logged traceback.
            raise KeyError('unknown command {!r}'.format(name)) from None
        result = handler(*command[1:])
        if inspect.isawaitable(result):
            result = await result
        return result
class RedisServer(object):
    """Registry of Redis command handlers plus a per-connection protocol factory."""

    def __init__(self):
        # Lower-case command name -> handler callable.
        self.commands = {}

    def command(self, fn):
        """Register *fn* under ``fn.__name__`` (lower-cased); usable as a decorator.

        Returns *fn* unchanged, so the decorated name still refers to the
        function rather than being rebound to ``None``.

        Raises:
            ValueError: if a command with the same name is already registered.
        """
        name = fn.__name__.lower()
        if name in self.commands:
            raise ValueError('Command name already taken')
        self.commands[name] = fn
        return fn

    def make_handler(self, loop):
        """Return a zero-argument protocol factory suitable for ``loop.create_server``."""
        def connection_handler():
            return RedisCommandHandler(loop=loop, commands=self.commands)
        return connection_handler
if __name__ == '__main__':
    # Demo: an in-memory key/value store exposing GET and SET on 127.0.0.1:8888.
    data = {}
    redis_server = RedisServer()

    # Note: command names come from the function names, so these two
    # deliberately shadow the builtins within this script.
    @redis_server.command
    def get(key):
        """GET key -> stored value, or nil when the key is missing."""
        return data.get(key)

    @redis_server.command
    def set(key, value):
        """SET key value -> 'OK' (sent as a bulk string rather than +OK)."""
        data[key] = value
        return 'OK'

    # Create the loop explicitly: asyncio.get_event_loop() without a running
    # loop is deprecated and no longer creates one on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    listener = loop.run_until_complete(
        loop.create_server(redis_server.make_handler(loop), '127.0.0.1', 8888))
    print('Server listening on :8888')
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        listener.close()
        loop.run_until_complete(listener.wait_closed())
        loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment