import asyncio
from functools import wraps

from aiohttp import web
from websockets import handshake
from websockets import WebSocketCommonProtocol


# Contents of index.html, loaded once at import time and served as UTF-8
# bytes. The file is read in text mode with open()'s default encoding (as
# before) and re-encoded; the context manager guarantees the handle is
# closed instead of being left for the garbage collector.
with open('index.html') as _index_file:
    INDEX = _index_file.read().encode('utf-8')


class WebSocketResponse(web.Response):
    """HTTP response answering a WebSocket opening handshake.

    For an HTTP/1.1 request for which ``handshake.check_request`` yields a
    key, the response is ``101`` carrying the headers produced by
    ``handshake.build_response``, and ``request.transport.close`` is replaced
    by *switch_protocols*. Otherwise the response is a ``400``.
    """

    def __init__(self, request, switch_protocols):
        """Build the handshake reply for *request*.

        :param request: incoming request; its ``_version``, ``headers`` and
            ``transport`` attributes are read.
        :param switch_protocols: zero-argument callable installed as
            ``request.transport.close`` when the handshake succeeds.
        """
        http11 = request._version == (1, 1)

        # Snapshot the headers once instead of rebuilding the dict on every
        # lookup. Keys are looked up upper-cased, as in the original code.
        request_headers = dict(request.headers)

        def get_header(name):
            return request_headers[name.upper()]

        # NOTE(review): depending on the websockets version, check_request
        # may raise on a bad handshake rather than return a falsy key --
        # confirm against the installed version.
        key = handshake.check_request(get_header)
        if not http11 or not key:
            # Previously ``super('...', status=400)``, which raised TypeError
            # instead of initialising the response.
            super().__init__(
                body=b'Invalid WebSocket handshake.\n',
                status=400,
            )
        else:
            headers = dict()
            set_header = headers.__setitem__
            handshake.build_response(set_header, key)
            self.switch_protocols = switch_protocols
            super().__init__(status=101, headers=headers)
            self._keep_alive = True
            # NOTE(review): presumably aiohttp calls transport.close() once
            # the response is written; replacing it hijacks that call to
            # perform the protocol switch instead -- confirm against the
            # aiohttp version in use.
            request.transport.close = switch_protocols


@asyncio.coroutine
def handle(request):
    """Answer *request* with the index page held in ``INDEX``."""
    page = web.Response(body=INDEX)
    return page


def websocket(handler):
    """Decorate a WebSocket coroutine so it can be used as a request handler.

    *handler* is called as ``handler(ws, request, *args, **kwargs)`` where
    ``ws`` is a ``WebSocketCommonProtocol`` attached to the request's
    transport. The returned wrapper takes ``(request, *args, **kwargs)`` and
    returns a ``WebSocketResponse`` whose ``switch_protocols`` callback
    performs the hand-over from the HTTP protocol to the WebSocket protocol.
    """

    @asyncio.coroutine
    @wraps(handler)
    def wrapper(request, *args, **kwargs):
        transport = request.transport
        http_protocol = transport._protocol

        @asyncio.coroutine
        def run_ws_handler(ws):
            # Close the connection even when the handler raises; otherwise
            # the transport, already detached from aiohttp, would stay open.
            try:
                yield from handler(ws, request, *args, **kwargs)
            finally:
                yield from ws.close()

        def switch_protocols():
            # Re-point the transport at a fresh WebSocket protocol instance.
            ws_protocol = WebSocketCommonProtocol()
            transport._protocol = ws_protocol
            ws_protocol.connection_made(transport)

            # Ensure aiohttp doesn't interfere.
            http_protocol.transport = None

            # ``asyncio.async`` is deprecated and a SyntaxError on
            # Python 3.7+; ``ensure_future`` (3.4.4+) is its replacement.
            asyncio.ensure_future(run_ws_handler(ws_protocol))

        return WebSocketResponse(request, switch_protocols)

    return wrapper


@websocket
@asyncio.coroutine
def ws(ws, request):
    """Demo WebSocket handler.

    While the connection reports itself open, wait for a message; each
    non-empty message is printed and answered with a fixed greeting.
    """
    while ws.open:
        message = yield from ws.recv()
        if not message:
            # Nothing usable received; re-check whether the socket is open.
            continue
        print(ws, request, message)
        yield from ws.send("Hello from aiohttp powered websocket!")


@asyncio.coroutine
def init(loop):
    """Create the application, register its routes and start listening.

    :param loop: event loop used for the application and the server.
    :returns: the server object produced by ``loop.create_server``.
    """
    application = web.Application(loop=loop)

    # Both endpoints answer GET requests; register them in order.
    for path, view in (('/', handle), ('/ws', ws)):
        application.router.add_route('GET', path, view)

    protocol_factory = application.make_handler()
    listener = yield from loop.create_server(
        protocol_factory, '127.0.0.1', 8080)
    print("Server started at http://127.0.0.1:8080")
    return listener


# Script entry point: obtain the event loop, run init() to completion so the
# server is listening, then keep the loop running to serve requests.
loop = asyncio.get_event_loop()
loop.run_until_complete(init(loop))
loop.run_forever()