Created
November 23, 2020 11:09
-
-
Save Mause/be2c6faacaa6fae97d12838c79e0b4dd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import AsyncGenerator

from fastapi import Depends, FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.testclient import TestClient
from msgpack import Packer, Unpacker
@dataclass
class Box:
    """Wrap an async generator so it can be consumed with ``async for``.

    Iterating a ``Box`` simply delegates to the wrapped ``stream``.

    NOTE(review): presumably this wrapper exists so the generator can be
    handed back from a FastAPI dependency as a plain value -- confirm.
    """

    # The wrapped async generator; iteration over the Box is forwarded to it.
    stream: AsyncGenerator

    def __aiter__(self):
        """Return the async iterator of the wrapped ``stream``."""
        return self.stream.__aiter__()
async def get_message_pack(request: Request) -> Box:
    """Decode the incoming request body as a stream of MessagePack objects.

    The body is read chunk by chunk via ``request.receive()`` and fed to an
    ``Unpacker``; every object that becomes complete is yielded as soon as
    it is available.

    Args:
        request: The incoming request whose body is consumed.

    Returns:
        A ``Box`` wrapping an async generator of the decoded objects.
    """

    async def internal():
        unpacker = Unpacker(raw=False)
        while True:
            res = await request.receive()
            if res.get('type') == 'http.disconnect':
                # A disconnect message carries no body; stop instead of
                # failing on the missing key.
                break
            unpacker.feed(res.get('body', b''))
            # Drain every object completed by this chunk. A single chunk may
            # hold several objects; a trailing partial object stays buffered
            # in the unpacker until more data is fed.
            for message in unpacker:
                yield message
            if not res.get('more_body'):
                break

    return Box(internal())
class MessagePackResponse(StreamingResponse):
    """Streaming response that MessagePack-encodes each item of ``content``."""

    def __init__(self, content, *args, **kwargs):
        """Wrap ``content`` so each item is packed before being streamed.

        Args:
            content: A sync or async iterable of objects to serialize.
            *args: Forwarded unchanged to ``StreamingResponse``.
            **kwargs: Forwarded unchanged to ``StreamingResponse``.
        """
        # One packer is shared by every item of this response.
        packer = Packer()

        async def internal():
            # Accept any async iterable (not only async generators), so
            # objects that merely define ``__aiter__`` are streamed as well.
            if isinstance(content, AsyncIterable):
                async for item in content:
                    yield packer.pack(item)
            else:
                for item in content:
                    yield packer.pack(item)

        super().__init__(internal(), *args, **kwargs)
def main():
    """Run an in-process demo of streaming MessagePack in both directions.

    Builds a FastAPI app with one POST route, sends it five packed
    ``{'hello': i}`` objects through a ``TestClient`` and prints both the
    segments seen by the server and the replies decoded by the client.
    """
    app = FastAPI()

    @app.post('/')
    async def whatever(item=Depends(get_message_pack)):
        async def internal():
            # Answer every received segment with a running reply counter.
            count = 0
            async for segment in item:
                print(segment)
                yield {'reply': count}
                count += 1

        return MessagePackResponse(internal())

    tc = TestClient(app)

    def stream():
        # Request body generator: one packed object per yielded chunk.
        packer = Packer()
        for i in range(5):
            yield packer.pack({'hello': i})

    # NOTE(review): ``stream=True`` and ``.raw`` look like the
    # requests-based TestClient API -- confirm against the installed version.
    unp = Unpacker(tc.post('/', data=stream(), stream=True).raw)
    for thing in unp:
        print(thing)
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment