Last active
May 16, 2019 05:40
-
-
Save ahopkins/756ab95069998c8899719196d20e3c1e to your computer and use it in GitHub Desktop.
Limetree POC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Package metadata; run() echoes both values in its startup banner.
__author__ = "Adam Hopkins"
__version__ = "0.1-prealpha"
import trio | |
from itertools import count | |
from functools import partial | |
from functools import lru_cache | |
from collections import defaultdict | |
from typing import Optional, List | |
from datetime import datetime | |
# Maximum number of bytes requested per receive_some() call.
BUFSIZE = 16384
# NOTE(review): TIMEOUT is not referenced by the visible code, and SLEEP_TIME
# only appears in commented-out code; confirm before relying on them.
TIMEOUT = 10
SLEEP_TIME = 0.01

# Type names allowed after ":" in a dynamic route segment such as "id:int".
TYPES = ("str", "int", "bool")

# Segment spellings that a ":bool" route parameter converts to True / False.
BOOL_TRUE = (
    "y", "Y",
    "yes", "YES", "Yes",
    "true", "True", "TRUE",
    "1",
)
BOOL_FALSE = (
    "n", "N",
    "no", "NO", "No",
    "false", "False", "FALSE",
    "0",
)

# Terminator appended to every outgoing message body.
EOM = b"\r\n"
# ASCII-art banner printed by run() at startup.
# NOTE(review): runs of spaces in the art appear to have been collapsed when
# this file was captured, and each line carries a trailing " | |" artifact;
# compare against the original source before shipping.
LOGO = """ | |
___ ___ __ __ _______ _______ ______ _______ _______ | |
| | | || |_| || || || _ | | || | | |
| | | || || ___||_ _|| | || | ___|| ___| | |
| | | || || |___ | | | |_||_ | |___ | |___ | |
| |___ | || || ___| | | | __ || ___|| ___| | |
| || || ||_|| || |___ | | | | | || |___ | |___ | |
|_______||___||_| |_||_______| |___| |___| |_||_______||_______| | |
"""
################################################################################ | |
################################################################################ | |
################################################################################ | |
class Route(defaultdict):
    """A single node of the routing tree.

    No ``default_factory`` is configured, so looking up a missing key still
    raises KeyError, just like a plain dict. Every node holds a ``"routes"``
    dict of child nodes keyed by path segment. A node that terminates a
    registered path additionally holds a ``"handler"`` entry.
    """

    # Render like a plain dict instead of "defaultdict(None, {...})".
    __repr__ = dict.__repr__

    def __init__(self):
        defaultdict.__init__(self)
        self.update(routes={})

    @property
    def routes(self):
        """Child nodes keyed by path segment."""
        return self["routes"]

    @property
    def handler(self):
        """The stored handler entry, or None if this node has none."""
        return self.get("handler")
class Router:
    """Path router backed by one ``Route`` tree shared by all instances.

    ``_mapping`` is a class attribute, so ``add`` (a classmethod, called from
    ``Endpoint.__init_subclass__``) and every ``Router`` instance work on the
    same tree.

    A registered segment written ``name:type`` (or just ``:type``) is dynamic.
    It captures the request segment, converted to ``type``. ``type`` must be
    one of ``TYPES``; an empty type means ``str``. The value is captured as
    keyword argument ``name``, or as a positional argument when the name is
    empty. Literal segments always take precedence over dynamic ones.
    """

    _mapping = Route()

    def __init__(self):
        # Replace registered (class, path) tuples with handler instances.
        self.__class__._hydrate_route(self._mapping)

    @property
    def mapping(self) -> Route:
        """Root node of the shared routing tree."""
        return self._mapping

    @classmethod
    def add(cls, path: str, handler) -> None:
        """Register ``handler`` (an Endpoint class) at ``path``.

        Missing intermediate nodes are created. The node for the final segment
        stores ``(handler, path)`` until ``_hydrate_route`` instantiates it.

        Raises:
            ValueError: if ``path`` does not start with "/".
        """
        # TODO:
        # - At startup, check each level to see if the path exists and
        #   raise a warning if there is a potential collision
        # - Predefine protected keywords for a request, and raise an Exception
        #   if a keyword is assigned to a dynamic path variable
        parts = cls._get_parts(path)
        pointer = cls._mapping
        for i, part in enumerate(parts):
            if part in pointer["routes"]:
                current = pointer["routes"][part]
            else:
                current = Route()
            h = handler if i + 1 == len(parts) else None
            if h:
                current["handler"] = (h, path)
            pointer["routes"][part] = current
            pointer = current

    @classmethod
    def _hydrate_route(cls, route: Route) -> Route:
        """Instantiate the pending handlers in ``route`` and its descendants."""
        handler = route.handler
        # Only (handler_class, path) tuples still need instantiating. An
        # earlier Router() may already have replaced them with instances, and
        # unpacking one of those would raise TypeError.
        if isinstance(handler, tuple):
            handler_class, path = handler
            if handler_class:
                route["handler"] = handler_class(path)
        for key, child in route.routes.items():
            route.routes[key] = cls._hydrate_route(child)
        return route

    @staticmethod
    def _get_parts(path: str) -> List[str]:
        """Split ``path`` into segments, ignoring one leading/trailing "/".

        Raises:
            ValueError: if ``path`` does not start with "/".
        """
        if not path.startswith("/"):
            # ValueError subclasses Exception, so existing handlers still catch it.
            raise ValueError("Path must start with a /")
        path = path[1:]
        if path.endswith("/"):
            path = path[:-1]
        return path.split("/")

    @staticmethod
    def _convert(value: str, typ: str):
        """Convert a segment to ``typ``; raise ValueError if it does not fit."""
        if typ == "int":
            return int(value)
        if typ == "bool":
            if value in BOOL_TRUE:
                return True
            if value in BOOL_FALSE:
                return False
            raise ValueError(value)
        return value

    @classmethod
    def _match_dynamic(cls, part: str, routes: dict, args: list, kwargs: dict) -> Route:
        """Bind ``part`` to the first dynamic child whose type accepts it.

        Appends to ``args`` or sets ``kwargs`` and returns the matched child.
        """
        for name in routes:
            if ":" not in name:
                continue
            # partition() tolerates extra colons; the remainder is then
            # simply rejected as an unknown type below.
            key, _, typ = name.partition(":")
            typ = typ or "str"
            if typ not in TYPES:
                raise Exception("Unacceptable type")
            try:
                value = cls._convert(part, typ)
            except ValueError:
                # This segment doesn't fit; try the next dynamic sibling.
                continue
            if key:
                kwargs[key] = value
            else:
                args.append(value)
            return routes[name]
        raise Exception("No route found")

    @classmethod
    def _match(cls, path: str, mapping: Route) -> tuple:
        """Walk ``mapping`` segment by segment; return (handler, args, kwargs)."""
        args = []
        kwargs = {}
        for part in cls._get_parts(path):
            routes = mapping.routes
            if part in routes:
                mapping = routes[part]
            else:
                mapping = cls._match_dynamic(part, routes, args, kwargs)
        return mapping.get("handler", None), args, kwargs

    @classmethod
    @lru_cache(maxsize=1024)
    def _match_cached(cls, path: str) -> tuple:
        """Memoised ``_match`` against the shared tree, keyed on ``path`` only."""
        return cls._match(path, cls._mapping)

    def match(self, path: str, mapping: Optional[Route] = None) -> tuple:
        """Resolve ``path`` to ``(handler, args, kwargs)``.

        When ``mapping`` is None the shared tree is searched and results are
        cached per path. An explicit ``mapping`` is matched uncached, because
        ``Route`` is unhashable. ``handler`` is None if ``path`` ends on a
        node without a handler.

        Raises:
            ValueError: if ``path`` does not start with "/".
            Exception: "No route found" if a segment matches nothing, or
                "Unacceptable type" for an unknown dynamic segment type.
        """
        if mapping is None:
            handler, args, kwargs = self._match_cached(path)
        else:
            handler, args, kwargs = self._match(path, mapping)
        # Fresh containers, so callers can never mutate a cached result.
        return handler, list(args), dict(kwargs)

    def display_mapping(self, mapping: Optional[Route] = None) -> None:
        """Print each handler in the tree with its path, depth-first."""
        if mapping is None:
            print("Mapped endpoints:")
            mapping = self.mapping
        if mapping.handler:
            print(f"{mapping.handler} at {mapping.handler.path}")
        if mapping.routes:
            for route in mapping.routes.values():
                self.display_mapping(route)
class Endpoint:
    """Base class for request handlers that register themselves as routes.

    Declare a handler with a class keyword, e.g.
    ``class Hello(Endpoint, path="/api/hello")``. The class is passed to
    ``Router.add`` and later instantiated with that path.
    """

    def __init_subclass__(cls, path: Optional[str] = None, **kwargs):
        # Forward remaining class keywords so cooperative bases/mixins
        # receive them, as the data-model protocol expects.
        super().__init_subclass__(**kwargs)
        # Subclasses declared without a path (e.g. shared abstract bases)
        # are simply not registered as routes.
        if path is not None:
            Router.add(path, cls)

    def __init__(self, path: str):
        self.path = path

    def __repr__(self):
        return f" <Endpoint:{self.__class__.__name__}>"
class Request:
    """Values captured from a matched route, handed to an endpoint.

    Positional captures are kept in ``args``. Each named capture becomes an
    attribute of the same name; a capture named "args" would replace ``args``.
    """

    def __init__(self, args, kwargs):
        self.args = args
        for name in kwargs:
            setattr(self, name, kwargs[name])
class Response:
    """Holds the payload an endpoint handler returned."""

    # Raw payload. HttpStream.send concatenates it with the bytes EOM, so it
    # is expected to be bytes.
    body: bytes

    def __init__(self, body: bytes):
        self.body = body
class ConnectionManager:
    """Tracks one service's live connections, keyed by an integer ident."""

    # Class-level on purpose: every manager draws from the same counter, so
    # idents stay unique across services.
    _connection_ident_counter = count()

    def __init__(self):
        # Per-instance registry. A class-level dict would be shared and mutated
        # by every manager, mixing connections of unrelated services.
        self._connections = {}

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args, **kwargs):
        # TODO:
        # - Make sure all connections are closed
        pass

    def _close_connection(self, ident):
        """Forget connection ``ident``; unknown or already-closed idents are ignored."""
        self._connections.pop(ident, None)

    def connect(self, connection):
        """Register ``connection`` and return its newly assigned ident."""
        ident = self.get_next_ident()
        self._connections[ident] = connection
        return ident

    def get_next_ident(self):
        """Return the next ident from the shared counter."""
        return next(self._connection_ident_counter)
class HttpStream:
    """Writes a single framed response to a trio stream."""

    def __init__(self, stream):
        self.stream = stream

    async def send(self, body):
        """Send ``body`` followed by ``EOM``, then half-close the stream.

        ``body`` must be bytes, since it is concatenated with the bytes ``EOM``.
        A peer that has already disconnected is ignored, not raised.
        """
        message = body + EOM
        # trio 0.11 renamed BrokenStreamError to BrokenResourceError, and newer
        # releases no longer ship the old alias. Resolve whichever exists, so
        # the except clause itself cannot raise AttributeError.
        broken_error = (
            getattr(trio, "BrokenResourceError", None) or trio.BrokenStreamError
        )
        try:
            await self.stream.send_all(message)
            await self.stream.send_eof()
        except broken_error:
            # They're already gone, nothing to do
            return
class BaseService:
    """Base class for a TCP service started by ``Server``.

    ``starter`` serves TCP on ``host:port`` and hands each accepted
    connection's stream to ``_runner``, which subclasses implement.
    """

    def __init__(self, server, manager, host, port):
        print(
            f"\n{self.__class__.__name__} listening on http://{host}:{port}\n"
        )
        self.server = server
        self.manager = manager
        self.host = host
        self.port = port

    @classmethod
    async def starter(cls, server, manager, host, port):
        """Build the service and serve TCP connections with its ``_runner``."""
        print(f" {cls.__name__} starter")
        instance = cls(server, manager, host, port)
        await trio.serve_tcp(instance._runner, port=port, host=host)

    async def _runner(self, stream):
        """Handle one connection; subclasses must override this."""
        # NotImplementedError marks the abstract hook and, being an Exception
        # subclass, is still caught by existing ``except Exception`` handlers.
        raise NotImplementedError("Not implemented")
class IncomingHttpService(BaseService):
    """TCP service that answers each connection with one endpoint response."""

    async def _runner(self, stream):
        """Handle one connection: read a packet, dispatch it, reply once.

        The connection is always removed from the manager, even when routing
        or the endpoint raises. Nothing is sent if the peer closed without
        sending data (``receive_some`` returned an empty packet).
        """
        http_stream = HttpStream(stream)
        ident = self.manager.connect(http_stream)
        try:
            packet = await stream.receive_some(BUFSIZE)
            print(f" Conn {ident}: received {packet}")
            if not packet:
                return
            # TODO: parse the request target from ``packet``; every request
            # is currently routed to /api/hello.
            endpoint, args, kwargs = self.server.router.match("/api/hello")
            request = Request(args, kwargs)
            retval = await endpoint.get(request)
            response = Response(retval)
            await http_stream.send(response.body)
            print(f" Sent: {response.body}")
        finally:
            self.manager._close_connection(ident)
class Server:
    """Owns the router and runs every attached service concurrently."""

    def __init__(self):
        self.router = Router()
        # Per-instance list. A class-level list would leak services attached
        # to one Server into every other Server.
        self._attached_services = []

    def attach(self, method, *args, **kwargs):
        """Queue service class ``method`` to be started with the given arguments."""
        # TODO:
        # - Do not allow services to attach after .service() has started
        self._attached_services.append((method, args, kwargs))

    async def _run_service(self, service, args, kwargs):
        """Run one service inside its own ConnectionManager context."""
        manager = ConnectionManager()
        async with manager:
            await service.starter(self, manager, *args, **kwargs)

    async def start_services(self):
        """Start all attached services concurrently and wait for them."""
        # One shared nursery is required. A nursery only exits after its
        # tasks finish and ``starter`` serves forever, so a nursery per
        # service inside the loop would never reach the second service.
        async with trio.open_nursery() as nursery:
            for service, args, kwargs in self._attached_services:
                nursery.start_soon(self._run_service, service, args, kwargs)

    def start(self):
        """Run all attached services under trio until they stop."""
        print("\nServer starting")
        trio.run(self.start_services)

    def close(self):
        print("\nServer closed")
def run(host="localhost", port=7777):
    """Print the startup banner, start live reloading, and serve until Ctrl+C.

    Args:
        host: interface the HTTP service is served on.
        port: TCP port the HTTP service is served on.
    """
    from aoiklivereload import LiveReloader
    import os

    clear_command = "cls" if os.name == "nt" else "clear"
    os.system(clear_command)

    # NOTE(review): aoiklivereload is third-party; presumably this watches the
    # source tree and restarts the process on change. Confirm before relying on it.
    watcher = LiveReloader()
    watcher.start_watcher_thread()

    print(
        f"By: {__author__}",
        f"Version: {__version__}",
        f"Time: {datetime.now()}",
        LOGO,
        sep="\n",
    )

    app = Server()
    app.attach(IncomingHttpService, host=host, port=port)
    app.router.display_mapping()
    try:
        app.start()
    except KeyboardInterrupt:
        app.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import limetree | |
class Hello(limetree.Endpoint, path="/api/hello"):
    """Example endpoint registered at /api/hello."""

    async def get(self, *args, **kwargs):
        """Ignore the request and answer with a fixed greeting."""
        greeting = b"hello"
        return greeting
def _main():
    """Start the limetree server with its default host and port."""
    limetree.run()


if __name__ == "__main__":
    _main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment