Last active
September 8, 2022 06:59
-
-
Save whg517/32270e9f2c7d7bec1d78e9eb0a94803e to your computer and use it in GitHub Desktop.
sample-aio-server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
from asyncio import Protocol, transports, AbstractEventLoop | |
from typing import Optional | |
# Module-level logger used by the framework for connection/startup messages.
logger = logging.getLogger(__name__)
# Configure the root logger at DEBUG (logging.basicConfig is a no-op if the
# root logger already has handlers).
logging.basicConfig(level=logging.DEBUG)
def get_remote_addr(transport):
    """Return the peer address of *transport* as ``(host, port)``, or ``None``.

    If the transport exposes its underlying socket, that socket is the only
    source consulted: a non-tuple address (e.g. a unix-domain path) or an
    ``OSError`` yields ``None``. Otherwise the ``peername`` extra info is used,
    and only a 2-element list/tuple is accepted.
    """
    sock = transport.get_extra_info("socket")
    if sock is None:
        peer = transport.get_extra_info("peername")
        if isinstance(peer, (list, tuple)) and len(peer) == 2:
            return str(peer[0]), int(peer[1])
        return None

    try:
        peer = sock.getpeername()
    except OSError:
        # This case appears to inconsistently occur with uvloop
        # bound to a unix domain socket.
        return None
    if not isinstance(peer, tuple):
        return None
    # IPv6 peers are 4-tuples; only host and port are kept.
    return str(peer[0]), int(peer[1])
def get_local_addr(transport):
    """Return the local address *transport* is bound to as ``(host, port)``.

    The underlying socket is consulted first when the transport exposes one;
    a non-tuple address (e.g. a unix-domain path) or an ``OSError`` from
    ``getsockname()`` yields ``None``. Otherwise the ``sockname`` extra info
    is used, and only a 2-element list/tuple is accepted.

    Returns ``None`` when no usable address is available.
    """
    socket_info = transport.get_extra_info("socket")
    if socket_info is not None:
        try:
            info = socket_info.getsockname()
        except OSError:
            # Handled the same way get_remote_addr handles getpeername():
            # the socket may already be closed (or be a uvloop unix socket),
            # so report "unknown" instead of raising.
            return None
        return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
    info = transport.get_extra_info("sockname")
    if isinstance(info, (list, tuple)) and len(info) == 2:
        return str(info[0]), int(info[1])
    return None
class MyProtocol(Protocol):
    """asyncio protocol that hands every received chunk to an application.

    The application is called as ``application(receiver, send)`` where
    ``receiver()`` returns the bytes of the most recent ``data_received``
    call and ``send(text)`` writes ``text`` (UTF-8 encoded) to the client.
    A non-``None`` return value from the application is also sent.
    """

    def __init__(self, application):
        self.app = application
        self._transport: Optional[transports.Transport] = None
        self._data: Optional[bytes] = None
        # Peer address captured on connect: by the time connection_lost runs
        # the socket may be closed and unable to report its peer any more.
        self._peer = None

    def data_received(self, data: bytes) -> None:
        # NOTE: each chunk is handled as a complete request; a request split
        # across several reads is not reassembled.
        self._data = data
        self.handle_request()

    def eof_received(self) -> Optional[bool]:
        # The base Protocol returns None, which (per the asyncio docs) lets
        # the transport close itself.
        return super().eof_received()

    def connection_made(self, transport: transports.Transport) -> None:
        self._transport = transport
        self._peer = get_remote_addr(transport)
        logger.info(f'Accept client: {self._peer}')

    def connection_lost(self, exc: Optional[Exception]) -> None:
        logger.info(f'Client: {self._peer} close')

    def handle_request(self) -> None:
        """Run the application for the current chunk and forward its result."""
        resp = self.app(self.receiver, self.send)
        # Apps may reply through send() and return None; only forward an
        # actual return value so nothing is written twice.
        if resp is not None:
            self.send(resp)

    def receiver(self) -> Optional[bytes]:
        """Return the raw bytes of the latest chunk received (None before any)."""
        return self._data

    def send(self, resp: Optional[str]) -> None:
        """Write *resp* to the client, encoded as UTF-8 (None is sent as '')."""
        if resp is None:
            resp = ''
        self._transport.write(resp.encode())
class BaseMiddleware:
    """This is not user middleware.

    Wraps an application callable ``app(receiver, send)`` and forwards every
    call to it unchanged; subclasses override ``__call__`` to add behavior.
    """

    def __init__(self, application):
        self.app = application

    def __call__(self, receiver, send):
        # Return the wrapped app's result so callers that act on it
        # (MyProtocol.handle_request assigns it to ``resp``) still see it.
        return self.app(receiver, send)
class ExceptionMiddleWare(BaseMiddleware):
    """Unified exception handling.

    Any exception raised by the wrapped app is logged with its traceback and
    answered with an HTTP 500 reply instead of propagating out of the
    protocol callback. On success the app's return value is passed through.
    """

    def __call__(self, receiver, send):
        try:
            return self.app(receiver, send)
        except Exception:
            # Top-level request boundary: broad catch is intentional.
            logger.exception('Unhandled error while handling request')
            body = 'error'
            # Route handlers reply with full HTTP responses, so the error
            # reply must be one too or the client cannot parse it.
            send(
                'HTTP/1.1 500 Internal Server Error\r\n'
                f'Content-Length: {len(body.encode())}\r\n'
                '\r\n'
                f'{body}'
            )
            return None
class Application:
    """Tiny HTTP application: route table, middleware stack and server runner.

    Handlers are registered with :meth:`get`, receive the raw request bytes
    read by the protocol, and must return a complete HTTP response string.
    """

    def __init__(
        self,
        host: Optional[str] = '127.0.0.1',
        port: Optional[int] = 8080,
        *,
        loop: Optional[AbstractEventLoop] = None
    ):
        self._host = host
        self._port = port
        self._loop = loop or self._default_loop()
        self._middleware = []
        self._routes = {}
        self._server = None
        # Callable handed to each connection; start() replaces it with the
        # middleware-wrapped version of ``self``.
        self.app = self

    @staticmethod
    def _default_loop() -> AbstractEventLoop:
        """Return the running loop, or a new one when none is running."""
        # asyncio.get_event_loop() without a running loop is deprecated and
        # no longer creates a loop implicitly on recent Python versions.
        try:
            return asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.new_event_loop()

    def middleware(self, *args):
        """Register middleware classes (each called as ``mw(app)``).

        The first registered middleware ends up outermost of the user stack;
        all of them sit inside the built-in ExceptionMiddleWare.
        """
        self._middleware.extend(args)

    async def create_server(self):
        """Bind the listening server and log the address it listens on."""
        server = await self._loop.create_server(
            lambda: MyProtocol(self.app),
            self._host,
            self._port
        )
        # getsockname() may return a 4-tuple for IPv6; keep host and port.
        host, port = server.sockets[0].getsockname()[:2]
        logger.info(f'Server start: http://{host}:{port}')
        return server

    def _build_app(self):
        """Wrap ``self`` in the registered middleware plus ExceptionMiddleWare."""
        app = self
        # Wrap innermost-first; ExceptionMiddleWare goes on last so it also
        # catches errors raised by user middleware.
        for mw in [*reversed(self._middleware), ExceptionMiddleWare]:
            app = mw(app)
        return app

    async def start(self):
        # Rebuild from ``self`` every time so repeated starts don't
        # double-wrap the middleware stack.
        self.app = self._build_app()
        self._server = await self.create_server()

    def run(self):
        """Start the server and serve until the loop is stopped."""
        # Awaiting start() synchronously makes bind errors (e.g. port in use)
        # raise here instead of being lost inside an unobserved task.
        self._loop.run_until_complete(self.start())
        self._loop.run_forever()

    def __call__(self, receiver, send):
        """Dispatch the request to the handler registered for its path."""
        data = receiver()
        path = self._request_path(data)
        if path is None:
            send(self._plain_response('400 Bad Request', 'Bad Request'))
            return
        handler = self._routes.get(path)
        if handler is None:
            send(self._plain_response('404 Not Found', 'Not Found'))
            return
        send(handler(data))

    @staticmethod
    def _request_path(data):
        """Return the path of the request line (query/fragment stripped), or None."""
        if isinstance(data, (bytes, bytearray)):
            # latin-1 decodes any byte sequence, so parsing never fails here.
            data = bytes(data).decode('latin-1')
        if not data:
            return None
        parts = data.partition('\n')[0].split()
        if len(parts) < 2:
            return None
        return parts[1].split('?', 1)[0].split('#', 1)[0]

    @staticmethod
    def _plain_response(status, body):
        """Build a minimal HTTP/1.1 response with a correct byte Content-Length."""
        return f'HTTP/1.1 {status}\r\nContent-Length: {len(body.encode())}\r\n\r\n{body}'

    def get(self, route):
        """Decorator registering ``func`` as the handler for *route*.

        The first registration for a route wins; the function is returned
        unchanged so it stays directly callable with its own name.
        """
        def _register(func):
            self._routes.setdefault(route, func)
            return func
        return _register
# Http framework end | |
# ################################################################# | |
# Http application start | |
app = Application()


@app.get('/')
def index(request):
    """Answer ``/`` with a fixed ``hello world`` body.

    *request* is the raw request data the framework read from the client.
    """
    logger.debug('Request: %r', request)
    body = 'hello world'
    # Content-Length counts bytes, and the reply is sent UTF-8 encoded, so
    # measure the encoded body rather than the number of characters.
    length = len(body.encode('utf-8'))
    return f'HTTP/1.1 200 Success\r\nContent-Length: {length}\r\n\r\n{body}'
if __name__ == '__main__':
    # Script entry point: serve the demo application until the loop stops.
    app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.