-
-
Save zgrge/be704c80b8775024dcfd072a2d005e09 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import struct
import sys
import traceback

import trio
################################################################
# Helper for reading exactly N bytes from a stream
################################################################
class UnexpectedEOFError(Exception):
    """Raised when the peer closes the stream before a full read completes."""
# Utility function
async def receive_exactly(stream, num_bytes):
    """Read exactly ``num_bytes`` bytes from ``stream``.

    ``stream.receive_some`` may return fewer bytes than requested, so this
    keeps asking for the remainder until the full count has arrived.

    Args:
        stream: Object with an async ``receive_some(max_bytes)`` method that
            returns an empty bytes object once the peer has closed.
        num_bytes: Number of bytes to read; must be non-negative.

    Returns:
        A ``bytearray`` of exactly ``num_bytes`` bytes.

    Raises:
        ValueError: If ``num_bytes`` is negative.
        UnexpectedEOFError: If the stream ends before ``num_bytes`` bytes
            have been received.
    """
    if num_bytes < 0:
        raise ValueError(f"num_bytes must be non-negative, got {num_bytes}")
    data = bytearray()
    while len(data) < num_bytes:
        # Request only what is still missing so we never consume bytes that
        # belong to whatever follows on the stream.
        chunk = await stream.receive_some(num_bytes - len(data))
        if not chunk:
            raise UnexpectedEOFError("other side closed connection")
        data += chunk
    # Sanity check against a stream that hands back more than it was asked for.
    assert len(data) == num_bytes
    return data
################################################################
# Read messages that start with a 2-byte length field, using receive_exactly
################################################################
async def read_message(stream, *, length_format=">H"):
    """Read one length-prefixed message from ``stream``.

    A message is a length field, encoded with the ``struct`` format
    ``length_format``, followed by that many bytes of body.

    Args:
        stream: Stream to read from; passed through to ``receive_exactly``.
        length_format: ``struct`` format of the length prefix. It must unpack
            to a single integer. The default ``">H"`` is a 2-byte big-endian
            unsigned field. Use ``"<H"`` if the field is little-endian, or
            e.g. ``">I"`` for a 4-byte big-endian field.

    Returns:
        The header and body concatenated (the length prefix is kept), as a
        ``bytearray``.

    Raises:
        UnexpectedEOFError: If the stream ends mid-message (raised by
            ``receive_exactly``).
    """
    # Derive the header size from the format so the two can never disagree.
    header_size = struct.calcsize(length_format)
    header = await receive_exactly(stream, header_size)
    (message_size,) = struct.unpack(length_format, header)
    body = await receive_exactly(stream, message_size)
    return header + body
################################################################
# The actual rewrite logic -- you need to fill this in!
################################################################
def rewrite_request(request):
    """Return the request to forward to the upstream server.

    Placeholder: currently returns ``request`` unchanged.

    ``request`` is the complete message as returned by ``read_message``, so it
    still includes its length prefix. If the rewrite changes the body length,
    that prefix presumably has to be updated to match (assuming the upstream
    server frames requests the same way).
    """
    # TODO: implement the real rewrite here.
    return request
################################################################
# The I/O code that glues it all together
################################################################
async def handle_one_client(client_stream):
    """Proxy a single request/response exchange for one client connection.

    Reads one length-prefixed message from ``client_stream`` and passes it
    through ``rewrite_request``. The result is sent to
    ``SERVER_HOST:SERVER_PORT``. One length-prefixed reply is then read back
    from the server and relayed to the client.

    Errors are reported on stderr and then dropped, so that one bad
    connection does not stop the proxy.
    """
    try:
        # Read the client's request before dialing upstream: a client that
        # hangs up early then never costs us a server connection.
        request = await read_message(client_stream)
        rewritten_request = rewrite_request(request)
        server_stream = await trio.open_tcp_stream(SERVER_HOST, SERVER_PORT)
        async with server_stream:
            await server_stream.send_all(rewritten_request)
            response = await read_message(server_stream)
        await client_stream.send_all(response)
    except UnexpectedEOFError as exc:
        # A peer closing mid-message is routine for a proxy; one line is
        # enough, no traceback needed.
        print(f"Connection closed mid-message: {exc}", file=sys.stderr)
    except Exception:
        # Log and swallow. trio's serve_tcp does not catch handler
        # exceptions, so letting one escape would crash the whole server.
        # Send the header to stderr as well, next to the traceback.
        print("Got an error:", file=sys.stderr)
        traceback.print_exc()
# Connection settings. These were referenced but never defined; fill them in
# before running. main() refuses to start while any of them is still None.
SERVER_HOST = None  # host name or IP address of the upstream server
SERVER_PORT = None  # TCP port of the upstream server
PROXY_PORT = None  # TCP port this proxy listens on


async def main():
    """Run the proxy, accepting TCP connections on ``PROXY_PORT``.

    Each accepted connection is handled by ``handle_one_client``.

    Raises:
        RuntimeError: If any of the connection settings above is still None.
    """
    settings = {
        "SERVER_HOST": SERVER_HOST,
        "SERVER_PORT": SERVER_PORT,
        "PROXY_PORT": PROXY_PORT,
    }
    missing = [name for name, value in settings.items() if value is None]
    if missing:
        raise RuntimeError(f"set {', '.join(missing)} before running the proxy")
    await trio.serve_tcp(handle_one_client, PROXY_PORT)


# Only start serving when executed as a script, not when imported.
if __name__ == "__main__":
    trio.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment