Last active
August 29, 2015 14:04
-
-
Save meganehouser/79d3dcdda3abeb00adc5 to your computer and use it in GitHub Desktop.
simple key value store (in memory) inspired by tristanwietsma/jack v0.1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Simple in-memory key/value store with publish/subscribe channels,
inspired by tristanwietsma/jack v0.1
https://github.com/tristanwietsma/jack
"""
from argparse import ArgumentParser
from asyncio import (
    Lock,
    Protocol,
    Task,
    coroutine,
    get_event_loop,
    new_event_loop,
    set_event_loop,
)
class Store:
    """In-memory key/value store with simple publish/subscribe channels."""

    def __init__(self):
        # key -> stored value
        self.data_map = {}
        # key -> list of subscriber callbacks, each called as callback(value)
        self.sub_map = {}

    def get(self, key):
        """Return ``(exists, value)`` for *key*; value is None when absent."""
        return key in self.data_map, self.data_map.get(key)

    def set(self, key, value):
        """Set *key* to *value*; always returns True."""
        self.data_map[key] = value
        return True

    def delete(self, keys):
        """Delete every key in the iterable *keys*; missing keys are ignored."""
        for key in keys:
            self.data_map.pop(key, None)

    def publish(self, key):
        """Return a primed generator that publishes values to *key*.

        Each ``pub.send(value)`` stores *value* under *key* and forwards it
        to the key's subscribers.  ``pub.close()`` ends the stream.
        """
        def publisher():
            while True:
                # close() raises GeneratorExit here, which simply ends the
                # generator.
                value = yield
                self.set(key, value)
                self._update_subscribers(key, value)

        pub = publisher()
        pub.send(None)  # prime: advance to the first ``yield``
        return pub

    def subscribe(self, key, out_going):
        """Register callable *out_going* to receive values published to *key*."""
        self.sub_map.setdefault(key, []).append(out_going)

    def unsubscribe(self, key, out_going):
        """Stop delivering values published to *key* to *out_going*.

        Only that callback is removed (every registration of it, compared
        with ``==``); other subscribers of *key* are left untouched.
        Unknown keys or callbacks are ignored.
        """
        subs, has_subs = self._fetch_subscribers(key)
        if not has_subs:
            return
        remaining = [sub for sub in subs if sub != out_going]
        if remaining:
            self.sub_map[key] = remaining
        else:
            # Drop empty channels so sub_map does not grow without bound.
            del self.sub_map[key]

    def _fetch_subscribers(self, key):
        """Return ``(subscribers, has_subscribers)`` for *key*."""
        return self.sub_map.get(key), key in self.sub_map

    def _update_subscribers(self, key, value):
        """Call every subscriber of *key* with *value*.

        A subscriber that raises is reported and unsubscribed; the others
        still receive the value.
        """
        subs, has_subs = self._fetch_subscribers(key)
        if not has_subs:
            return
        # Iterate over a snapshot: unsubscribing a failing callback changes
        # the registered list while delivery is still in progress.
        for out in list(subs):
            try:
                out(value)
            except Exception as e:
                print('update subscribe error:', e)
                self.unsubscribe(key, out)
class JackServer(Protocol):
    """asyncio protocol implementing jack's plain-text command set.

    Each received chunk is decoded and split on whitespace into a command:

    * ``GET key``           -> the stored value, or ``(None)`` if absent
    * ``SET key value``     -> ``OK`` (or ``FAIL`` if the store rejects it)
    * ``DEL key [key ...]`` -> ``OK``
    * ``PUB key``           -> ``READY``; afterwards every chunk received on
      this connection is published verbatim as the new value of *key*
    * ``SUB key``           -> values later published to *key* are written
      to this connection

    Commands with the wrong number of arguments and unknown verbs are
    ignored without a reply.  There is no framing: the code assumes one
    command per received chunk.
    """

    def __init__(self, db):
        super().__init__()
        self.db = db
        self.transport = None
        # Publisher generator from db.publish() once the client sent PUB;
        # while set, incoming chunks are values rather than commands.
        self.publisher = None
        # (key, callback) pairs registered via SUB, released on disconnect.
        self.subscriptions = []

    def connection_made(self, transport):
        peername = transport.get_extra_info('peername')
        print('connction from {}'.format(peername))
        self.transport = transport

    def connection_lost(self, exc):
        """Close this client's publisher and drop its subscriptions."""
        if self.publisher is not None:
            self.publisher.close()
            self.publisher = None
        # Without this the store would keep calling our callbacks after the
        # client is gone.
        for key, callback in self.subscriptions:
            self.db.unsubscribe(key, callback)
        self.subscriptions = []

    def data_received(self, data):
        """Handle one chunk of client input (a command or a published value)."""
        text = data.decode()
        print('data received', text)
        if self.publisher is not None:
            # Publish mode: the raw chunk is the value.  Handling it here,
            # instead of reading the socket behind the transport's back,
            # avoids competing with the transport for the same descriptor.
            self.publisher.send(text)
            return
        # split() with no separator also drops the trailing "\r\n" / "\n"
        # appended by line-oriented clients (telnet, nc); otherwise it would
        # become part of the key or value.
        args = text.split()
        if not args:
            return
        verb = args[0]
        if verb == "GET":
            if len(args) != 2:
                return
            ok, value = self.db.get(args[1])
            self.transport.write(value.encode() if ok else b"(None)")
        elif verb == "SET":
            if len(args) != 3:
                return
            ok = self.db.set(args[1], args[2])
            self.transport.write(b"OK" if ok else b"FAIL")
        elif verb == "DEL":
            if len(args) < 2:
                return
            self.db.delete(args[1:])
            self.transport.write(b"OK")
        elif verb == "PUB":
            if len(args) != 2:
                return
            self.publisher = self.db.publish(args[1])
            self.transport.write(b"READY")
        elif verb == "SUB":
            if len(args) != 2:
                return
            self._subscribe(args[1])

    def _subscribe(self, key):
        """Forward every value published to *key* to this client."""
        def forward(value):
            self.transport.write(value.encode())

        self.db.subscribe(key, forward)
        # Remember the callback so connection_lost() can unregister it.
        self.subscriptions.append((key, forward))
def start_server(port_no, host='127.0.0.1'):
    """Serve a fresh in-memory Store over TCP until interrupted.

    :param port_no: TCP port to listen on (an int, or a numeric string).
    :param host: interface to bind; defaults to loopback only.
    """
    db = Store()
    # Create and install an explicit loop: calling get_event_loop() with no
    # running loop is deprecated (3.10+) and raises on recent Pythons.
    loop = new_event_loop()
    set_event_loop(loop)
    server = None
    try:
        server = loop.run_until_complete(
            loop.create_server(lambda: JackServer(db), host, int(port_no)))
        print('serving on {}'.format(server.sockets[0].getsockname()))
        loop.run_forever()
    except KeyboardInterrupt:
        print('exit')
    finally:
        if server is not None:
            server.close()
        # Close the loop even when create_server() failed (e.g. port in use).
        loop.close()
if __name__ == '__main__':
    parser = ArgumentParser()
    # A positional argument is required unless nargs='?', so without it the
    # default below would never apply; type=int rejects non-numeric ports.
    parser.add_argument('port', nargs='?', type=int, default=2000,
                        help='tcp port number')
    args = parser.parse_args()
    start_server(args.port)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment