Last active May 5, 2022 05:25
-
-
Save mosquito/4dbfacd51e751827cda7ec9761273e95 to your computer and use it in GitHub Desktop.
AIOHTTP async proxy streaming
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import platform
import unittest
from http import HTTPStatus
from urllib.parse import unquote, urlencode

import aiohttp
import asynctest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from multidict import CIMultiDict, MultiDict
from yarl import URL
# Hop-by-hop headers (RFC 7230, section 6.1) describe a single transport-level
# connection, so a proxy must not forward them in either direction.
_HOP_BY_HOP_HEADERS = frozenset((
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'te',
    'trailer',
    'trailers',
    'transfer-encoding',
    'upgrade',
))


def _forwardable_headers(headers, *excluded) -> 'CIMultiDict':
    """Return a case-insensitive copy of *headers* that is safe to forward.

    Drops hop-by-hop headers, any header named in a ``Connection`` header,
    and every name in *excluded*. Names are compared case-insensitively.
    Repeated headers (e.g. several ``Set-Cookie``) are preserved.
    """
    dropped = set(_HOP_BY_HOP_HEADERS)
    dropped.update(name.lower() for name in excluded)
    # Headers listed in Connection are connection-specific too (RFC 7230 6.1).
    for value in headers.getall('Connection', ()):
        dropped.update(
            token.strip().lower() for token in value.split(',') if token.strip()
        )
    return CIMultiDict(
        [(name, value) for name, value in headers.items() if name.lower() not in dropped]
    )


async def redirect(request: web.Request, url: str) -> 'web.StreamResponse':
    """Proxy *request* to *url* and stream the upstream response back.

    The incoming request body is buffered in memory and sent upstream in one
    piece. The upstream response body is relayed to the client chunk by chunk
    as it arrives.

    :param request: the incoming aiohttp server request to forward.
    :param url: the target URL. If its path is empty, the incoming request's
        path is used instead (the target's query string is kept).
    :returns: the prepared, fully written streaming response.
    """
    parsed_url = URL(url)

    # Host is dropped so aiohttp derives it (including any non-default port)
    # from the target URL. Content-Length is recomputed by aiohttp from the
    # buffered body. Content-Encoding is dropped because aiohttp's server
    # decompresses request bodies by default, so read() returns decoded bytes.
    headers = _forwardable_headers(
        request.headers, 'Host', 'Content-Length', 'Content-Encoding'
    )
    headers['Router-Host'] = request.host
    headers['User-Agent'] = ("aiohttp/%(aiohttp_version)s (%(system_version)s)") % {
        "aiohttp_version": aiohttp.__version__,
        "system_version": platform.version(),
    }

    target_url = parsed_url
    if not parsed_url.path:
        # yarl URLs are immutable (assigning .path raises AttributeError), so
        # build a new URL. with_path() clears the query, so restore it.
        target_url = parsed_url.with_path(request.path).with_query(parsed_url.query)

    body = await request.read()

    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.request(request.method, target_url, data=body) as response:
            # The client decompresses the body by default, so the upstream
            # Content-Encoding/Content-Length no longer describe the bytes we
            # relay. Without a length, aiohttp frames the body itself.
            proxied_response = web.StreamResponse(
                status=response.status,
                headers=_forwardable_headers(
                    response.headers, 'Content-Encoding', 'Content-Length'
                ),
            )
            upstream_chunked = (
                response.headers.get('Transfer-Encoding', '').lower() == 'chunked'
            )
            # aiohttp refuses chunked encoding for HTTP/1.0 clients.
            if upstream_chunked and request.version >= (1, 1):
                proxied_response.enable_chunked_encoding()
            await proxied_response.prepare(request)
            async for data in response.content.iter_any():
                # write() is a coroutine and applies back-pressure itself.
                await proxied_response.write(data)
            await proxied_response.write_eof()
    return proxied_response
class TestRedirectHandler(web.View):
    """Test view that proxies the request to a caller-chosen URL.

    GET takes the target from the ``url`` query parameter. POST takes it from
    the ``url`` key of a JSON object body. A missing or malformed target is
    answered with ``400 Bad Request`` rather than an unhandled server error.
    """

    async def get(self):
        # aiohttp already percent-decodes query values. Decoding again with
        # unquote() would corrupt targets whose own query contains escapes.
        url = self.request.query.get('url')
        if not url:
            raise web.HTTPBadRequest(text="missing 'url' query parameter")
        return await redirect(self.request, url)

    async def post(self):
        try:
            payload = await self.request.json()
        except ValueError:
            raise web.HTTPBadRequest(text='request body is not valid JSON') from None
        url = payload.get('url') if isinstance(payload, dict) else None
        if not isinstance(url, str) or not url:
            raise web.HTTPBadRequest(text="missing 'url' string in JSON body")
        return await redirect(self.request, url)
class RedirectBase(asynctest.TestCase):
    """Base test case that serves ``TestRedirectHandler`` at ``/redirect``.

    Each test gets a fresh application and a started ``self.client``. Both are
    shut down in ``tearDown``.
    """

    async def get_application(self):
        """Build the application under test; override to customise it."""
        # The loop argument to Application is deprecated; the test server
        # binds the application to the running loop when it starts.
        return web.Application()

    async def setUp(self):
        super().setUp()
        self.app = await self.get_application()
        self.app.router.add_route('*', '/redirect', TestRedirectHandler)
        # aiohttp >= 3.0 requires TestClient to wrap a TestServer; passing the
        # application directly raises TypeError.
        self.client = TestClient(TestServer(self.app), loop=self.loop)
        await self.client.start_server()

    async def tearDown(self):
        # Closing the client also shuts down the test server.
        await self.client.close()
        super().tearDown()
class TestRedirect(RedirectBase):
    """End-to-end tests that proxy real requests to httpbin.org.

    NOTE(review): these tests need outbound network access to
    https://httpbin.org and will fail offline.
    """

    async def _proxy_get(self, target_url):
        """GET ``/redirect`` with *target_url* as the ``url`` query parameter."""
        query = urlencode({'url': target_url})
        return await self.client.request('GET', '/redirect?{}'.format(query))

    async def test_redirect_get_simple(self):
        redirected_response = await self._proxy_get('https://httpbin.org/user-agent')
        self.assertEqual(redirected_response.status, HTTPStatus.OK)
        # text() never returns None, so assert on real content instead:
        # httpbin echoes back the User-Agent that redirect() substitutes.
        result = await redirected_response.json()
        self.assertTrue(
            result['user-agent'].startswith('aiohttp/{}'.format(aiohttp.__version__))
        )

    async def test_redirect_post(self):
        data = {'url': 'https://httpbin.org/post', 'shmurl': 'qwerty'}
        redirected_response = await self.client.request('POST', '/redirect', json=data)
        self.assertEqual(redirected_response.status, HTTPStatus.OK)
        result = await redirected_response.json()
        # httpbin echoes the parsed JSON request body under the "json" key.
        self.assertIn('json', result)
        self.assertEqual(result['json']['shmurl'], 'qwerty')

    async def test_redirect_get_chunked(self):
        chunks = 10
        redirected_response = await self._proxy_get(
            'https://httpbin.org/stream/{}'.format(chunks)
        )
        self.assertEqual(redirected_response.status, HTTPStatus.OK)
        # /stream/<n> emits one JSON document per line.
        result = await redirected_response.text()
        self.assertEqual(len(result.splitlines()), chunks)

    async def test_redirect_get_long_blob(self):
        length = 10240
        redirected_response = await self._proxy_get(
            'https://httpbin.org/bytes/{}'.format(length)
        )
        self.assertEqual(redirected_response.status, HTTPStatus.OK)
        result = await redirected_response.read()
        self.assertEqual(len(result), length)
if __name__ == '__main__':
    # Run the TestCase classes of this module through unittest's CLI runner
    # ('__main__' is unittest.main's default module, spelled out here).
    unittest.main(module='__main__')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.