Created August 19, 2016 18:12
Save qix/e51f9583d126102124854a0384ab218b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import hiredis | |
import logging | |
import sys | |
from collections import deque | |
# Module-level logger; failures raised by command handlers are logged here.
logger = logging.getLogger(name='pylib.aio.redis_server')
class RedisProtocol(asyncio.Protocol):
    """Server side of the Redis serialization protocol (RESP) for asyncio.

    Incoming bytes are parsed with a ``hiredis.Reader``. Every complete
    request is passed to the coroutine ``command_received``, which a subclass
    implements, and runs as its own task. Replies are written strictly in
    request order, so pipelined clients receive answers in the order they
    sent requests even when a later command finishes first.
    """

    def __init__(self, loop):
        super().__init__()
        self.reader = hiredis.Reader()
        self.loop = loop
        # Command tasks whose replies have not been written yet, oldest first.
        self.tasks = deque()

    def connection_made(self, transport):
        self.transport = transport

    async def command_received(self, command):
        """Handle one parsed request and return the value to reply with.

        Subclasses must override this. The returned value is serialized by
        ``write_value``.
        """
        raise NotImplementedError()

    def task_done(self, task):
        """Done-callback for command tasks.

        Writes the replies of all finished tasks at the head of the queue.
        It stops at the first unfinished task, so replies are never
        reordered. *task* itself is not inspected.
        """
        while self.tasks and self.tasks[0].done():
            self.write_task(self.tasks.popleft())

    def data_received(self, data):
        """Feed *data* to the parser and start one task per complete request."""
        self.reader.feed(data)
        command = self.reader.gets()
        # gets() returns False when no complete request is buffered. A valid
        # request can still be falsy ([], b'', 0, None), so compare with
        # ``is not False``. Otherwise such a request would be dropped without
        # a reply, and the requests buffered behind it would not be handled
        # until more data arrived.
        while command is not False:
            task = asyncio.ensure_future(self.command_received(command), loop=self.loop)
            self.tasks.append(task)
            task.add_done_callback(self.task_done)
            command = self.reader.gets()

    def write_task(self, task):
        """Write the reply for a finished command task.

        If the command raised, was cancelled, or returned a value that cannot
        be serialized, a generic error reply is written instead. This keeps
        exactly one reply per request.
        """
        try:
            self.write_value(task.result())
        except (Exception, asyncio.CancelledError):
            # CancelledError is listed explicitly because it derives from
            # BaseException on Python 3.8+. KeyboardInterrupt and SystemExit
            # are deliberately not swallowed any more.
            logger.exception('Redis command threw an exception')
            self.write_error('Internal server error')

    def write_value(self, value):
        """Serialize *value* as a RESP reply and write it in a single call.

        Supported values:
        - None is sent as a null bulk string.
        - bytes, bytearray and memoryview are sent as a bulk string.
        - str is sent as a UTF-8 bulk string.
        - int (including bool) is sent as an integer.
        - list or tuple of the above is sent as an array.

        The whole reply is encoded before anything is written, so an
        unsupported nested element cannot leave a half-written reply.

        Raises:
            TypeError: for unsupported types.
            ValueError: for ints outside the signed 64-bit range.
        """
        self.transport.write(self._encode_value(value))

    def _encode_value(self, value):
        # Recursive RESP encoder behind write_value.
        if value is None:
            return b'$-1\r\n'
        if isinstance(value, (bytes, bytearray, memoryview)):
            return self._encode_bulk(bytes(value))
        if isinstance(value, str):
            return self._encode_bulk(value.encode('utf-8'))
        if isinstance(value, int):
            return self._encode_integer(value)
        if isinstance(value, (list, tuple)):
            parts = [self._encode_value(item) for item in value]
            return b'*' + str(len(parts)).encode('ascii') + b'\r\n' + b''.join(parts)
        raise TypeError('redis return type not handled: %s' % type(value).__name__)

    @staticmethod
    def _encode_line(symbol, message):
        # Simple strings and errors are CRLF-terminated and cannot contain
        # CR or LF. Replace them so a message cannot break reply framing.
        safe = message.replace('\r', ' ').replace('\n', ' ')
        return symbol + safe.encode('utf-8') + b'\r\n'

    @staticmethod
    def _encode_integer(value):
        # RESP integers are signed 64-bit. Raise instead of using assert so
        # the check survives ``python -O``. int() turns bools into 0/1.
        if not -(2 ** 63) <= value < 2 ** 63:
            raise ValueError('value was not in integer range')
        return b':' + str(int(value)).encode('ascii') + b'\r\n'

    @staticmethod
    def _encode_bulk(value):
        # Length-prefixed bulk string; *value* must already be bytes.
        return b'$' + str(len(value)).encode('ascii') + b'\r\n' + value + b'\r\n'

    def write_simple_string(self, message, symbol=b'+'):
        """Write a one-line reply: *symbol* followed by *message* and CRLF."""
        self.transport.write(self._encode_line(symbol, message))

    def write_integer(self, value):
        """Write a RESP integer reply. Raises ValueError outside int64 range."""
        self.transport.write(self._encode_integer(value))

    def write_error(self, message):
        """Write a RESP error reply (``-message``)."""
        self.write_simple_string(message, symbol=b'-')

    def write_null(self):
        """Write a RESP null bulk string (``$-1``)."""
        self.transport.write(b'$-1\r\n')

    def write_bytes(self, value):
        """Write *value* (bytes) as a RESP bulk string."""
        self.transport.write(self._encode_bulk(value))
class RedisCommandHandler(RedisProtocol):
    """Protocol that dispatches each request to a registered command handler.

    ``commands`` maps lower-case command names to callables. The first
    element of a request selects the handler (case-insensitively), and the
    remaining elements are passed to it as positional arguments. A handler
    may be a plain function or return a coroutine/future (for example an
    ``async def`` handler); in that case the result is awaited before the
    reply is written.
    """

    def __init__(self, loop, commands):
        super().__init__(loop)
        self.commands = commands

    async def command_received(self, command):
        name = command[0].decode('utf-8').lower()
        # An unknown name raises KeyError. The base class turns that into an
        # error reply and logs it.
        result = self.commands[name](*command[1:])
        # Support async handlers. Without this, the coroutine object itself
        # would be handed to the reply serializer and would never run.
        if asyncio.iscoroutine(result) or asyncio.isfuture(result):
            result = await result
        return result
class RedisServer:
    """Registry of command handlers plus a protocol factory for asyncio.

    Register handlers with the ``command`` decorator, then pass
    ``make_handler(loop)`` as the protocol factory to ``loop.create_server``.
    """

    def __init__(self):
        # Lower-case command name -> handler callable.
        self.commands = {}

    def command(self, fn=None, *, name=None):
        """Register *fn* as a command handler and return it unchanged.

        Use it as ``@server.command`` to take the command name from the
        function name, or as ``@server.command(name='del')`` to choose a name
        that is not a convenient Python identifier. Names are stored
        lower-case.

        Raises:
            ValueError: if the command name is already registered.
        """
        def register(func):
            key = (name if name is not None else func.__name__).lower()
            # Raise instead of assert so the check survives ``python -O``.
            if key in self.commands:
                raise ValueError('Command name already taken')
            self.commands[key] = func
            # Returning func keeps the decorated name bound to the function
            # rather than to None.
            return func

        if fn is None:
            return register
        return register(fn)

    def make_handler(self, loop):
        """Return a zero-argument protocol factory for ``loop.create_server``."""
        def connection_handler():
            return RedisCommandHandler(loop=loop, commands=self.commands)
        return connection_handler
if __name__ == '__main__':
    # Demo: an in-memory key/value store that answers GET and SET on
    # 127.0.0.1:8888.
    data = {}
    server = RedisServer()

    # Command names are taken from the function names, which is why this
    # script defines ``set`` (rebinding the builtin name inside this block).
    @server.command
    def get(key):
        return data.get(key)

    @server.command
    def set(key, value):
        data[key] = value

    # Calling get_event_loop() with no running loop is deprecated (and raises
    # on Python 3.14), so create and install a loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Named tcp_server so it does not shadow the RedisServer registry above.
        tcp_server = loop.run_until_complete(
            loop.create_server(server.make_handler(loop), '127.0.0.1', 8888))
        print('Server listening on :8888')
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            tcp_server.close()
            loop.run_until_complete(tcp_server.wait_closed())
    finally:
        # Close the loop even if create_server failed (e.g. port in use).
        loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.