Skip to content

Instantly share code, notes, and snippets.

@meganehouser
Last active August 29, 2015 14:04
Show Gist options
  • Save meganehouser/79d3dcdda3abeb00adc5 to your computer and use it in GitHub Desktop.
Save meganehouser/79d3dcdda3abeb00adc5 to your computer and use it in GitHub Desktop.
simple key value store (in memory) inspired by tristanwietsma/jack v0.1
"""
simple key value store (in memory)
inspired by tristanwietsma/jack v0.1
https://github.com/tristanwietsma/jack
"""
from argparse import ArgumentParser
from asyncio import (
Protocol,
coroutine,
get_event_loop,
Lock,
Task)
class Store:
    """In-memory key/value store with simple per-key publish/subscribe.

    ``data_map`` holds the current value of each key; ``sub_map`` maps a key
    to the list of callables notified whenever a publisher pushes a new value
    for that key.
    """

    def __init__(self):
        # map for key value
        self.data_map = {}
        # map for channel: key -> list of subscriber callables
        self.sub_map = {}

    def get(self, key):
        """Return ``(exists, value)``; ``value`` is ``None`` when missing."""
        return key in self.data_map, self.data_map.get(key)

    def set(self, key, value):
        """Set ``key`` to ``value``; always returns ``True``."""
        self.data_map[key] = value
        return True

    def delete(self, keys):
        """Delete every key in ``keys``; keys that do not exist are ignored."""
        for key in keys:
            self.data_map.pop(key, None)

    def publish(self, key):
        """Return a primed generator that publishes a stream to ``key``.

        Each ``pub.send(value)`` stores ``value`` under ``key`` and forwards
        it to the current subscribers of ``key``; ``pub.close()`` stops it.
        """
        def publisher():
            # close() raises GeneratorExit at the yield, which ends the loop.
            while True:
                value = yield
                self.set(key, value)
                self._update_subscribers(key, value)

        pub = publisher()
        pub.send(None)  # advance to the first ``yield`` so send() works
        return pub

    def subscribe(self, key, out_going):
        """Register callable ``out_going`` to receive values published on ``key``."""
        self.sub_map.setdefault(key, []).append(out_going)

    def unsubscribe(self, key, out_going):
        """Remove ``out_going`` from the subscribers of ``key``, if present.

        Other subscribers of ``key`` are left untouched; when the last one is
        removed the key is dropped from ``sub_map``.
        """
        subs, has_subs = self._fetch_subscribers(key)
        if has_subs and out_going in subs:
            subs.remove(out_going)
            if not subs:
                del self.sub_map[key]

    def _fetch_subscribers(self, key):
        """Return ``(subscribers, has_subscribers)`` for ``key``."""
        return self.sub_map.get(key), key in self.sub_map

    def _update_subscribers(self, key, value):
        """Call every subscriber of ``key`` with ``value``.

        A subscriber that raises is reported and unsubscribed; the remaining
        subscribers are still notified.
        """
        subs, has_subs = self._fetch_subscribers(key)
        if not has_subs:
            return
        # Iterate over a snapshot: unsubscribe() mutates the live list.
        for out in list(subs):
            try:
                out(value)
            except Exception as e:
                print('update subscribe error:', e)
                self.unsubscribe(key, out)
class JackServer(Protocol):
    """asyncio protocol serving one client connection of the store.

    Each received chunk is one space-separated command:

    * ``GET key``            -> the value, or ``(None)`` if the key is missing
    * ``SET key value``      -> ``OK`` (``FAIL`` if the store refuses)
    * ``DEL key [key ...]``  -> ``OK``
    * ``PUB key``            -> ``READY``; afterwards every chunk received on
      this connection is published verbatim as the new value of ``key``
    * ``SUB key``            -> no reply; values later published on ``key``
      are written to this connection

    A trailing CR/LF on a command is ignored so line-oriented clients work.
    Malformed or unknown commands are answered with ``ERROR``.
    """

    def __init__(self, db):
        super().__init__()
        self.db = db
        self.transport = None
        # Publisher generator from db.publish(); set once the client sends PUB.
        self.publisher = None
        # Cleared in connection_lost so the SUB callback stops writing.
        self.connected = False

    def connection_made(self, transport):
        peername = transport.get_extra_info('peername')
        print('connction from {}'.format(peername))
        self.transport = transport
        self.connected = True

    def connection_lost(self, exc):
        self.connected = False
        # NOTE: a SUB callback stays registered in the store; _push simply
        # becomes a no-op for this closed connection.
        if self.publisher is not None:
            self.publisher.close()
            self.publisher = None

    def data_received(self, data):
        text = data.decode()
        if self.publisher is not None:
            # Publishing mode: the raw chunk is the next value for the key.
            self.publisher.send(text)
            return
        print('data received', text)
        args = text.rstrip('\r\n').split(" ")
        handler = self._COMMANDS.get(args[0])
        if handler is None or not handler(self, args):
            self.transport.write(b"ERROR")

    # Each command handler returns False when the argument count is wrong.

    def _cmd_get(self, args):
        if len(args) != 2:
            return False
        ok, value = self.db.get(args[1])
        self.transport.write(value.encode() if ok else b"(None)")
        return True

    def _cmd_set(self, args):
        if len(args) != 3:
            return False
        ok = self.db.set(args[1], args[2])
        self.transport.write(b"OK" if ok else b"FAIL")
        return True

    def _cmd_del(self, args):
        if len(args) < 2:
            return False
        self.db.delete(args[1:])
        self.transport.write(b"OK")
        return True

    def _cmd_pub(self, args):
        if len(args) != 2:
            return False
        self.publisher = self.db.publish(args[1])
        self.transport.write(b"READY")
        return True

    def _cmd_sub(self, args):
        if len(args) != 2:
            return False
        self.db.subscribe(args[1], self._push)
        return True

    def _push(self, value):
        """Subscriber callback: forward a published value to this client."""
        if self.connected:
            self.transport.write(value.encode())

    # Command verb -> handler (plain functions, called as handler(self, args)).
    _COMMANDS = {
        "GET": _cmd_get,
        "SET": _cmd_set,
        "DEL": _cmd_del,
        "PUB": _cmd_pub,
        "SUB": _cmd_sub,
    }
def start_server(port_no, host='127.0.0.1'):
    """Serve a fresh in-memory Store over TCP until interrupted with Ctrl-C.

    :param port_no: TCP port to listen on.
    :param host: interface address to bind; defaults to loopback only.
    """
    db = Store()
    loop = get_event_loop()
    # Every connection gets its own JackServer instance sharing the one store.
    server = loop.run_until_complete(
        loop.create_server(lambda: JackServer(db), host, port_no))
    print('serving on {}'.format(server.sockets[0].getsockname()))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print('exit')
    finally:
        server.close()
        loop.close()
def parse_cli_args(argv=None):
    """Parse command-line options; ``port`` is optional and defaults to 2000."""
    parser = ArgumentParser()
    # nargs='?' makes the positional optional so the default actually applies;
    # type=int hands create_server a numeric port.
    parser.add_argument('port', nargs='?', type=int, default=2000,
                        help='tcp port number (default: 2000)')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_cli_args()
    start_server(args.port)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment