Last active
December 18, 2024 11:02
-
-
Save graingert/380daef873f817d2dab2c9825588ebfb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import anyio.streams.buffered | |
import anyio.abc | |
import socket | |
import functools | |
async def pipe(reader, writer):
    """Copy bytes from *reader* to *writer* until either side stops, then half-close.

    Reading stops on end-of-stream or when the reader is closed/broken;
    writing stops when the writer is closed/broken.  Afterwards the reader's
    underlying socket is shut down for reading and EOF is sent on *writer*.
    Both half-close steps are best effort: failures are ignored, because the
    usual reason the loop ended is that one of the two connections is already
    gone.
    """
    raw_socket = reader.extra(anyio.abc.SocketAttribute.raw_socket)
    try:
        while True:
            try:
                data = await reader.receive()
            except (
                anyio.EndOfStream,
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
            ):
                break
            try:
                await writer.send(data)
            except (
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
            ):
                break
    finally:
        try:
            try:
                raw_socket.shutdown(socket.SHUT_RD)
            except OSError:
                pass
        finally:
            try:
                await writer.send_eof()
            except (
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
                OSError,
            ):
                # The writer was closed or its peer went away; there is
                # nothing left to half-close, and raising here would tear
                # down the whole connection handler.
                pass
async def _read_client_hello(buffered):
    """Read the client's first TLS record: a 5-byte header plus its payload.

    Returns ``(header, record)`` when the header matches the expected
    handshake prefix, or ``None`` after printing a diagnostic when the client
    closed or reset the connection early or sent something else.
    """
    try:
        header = await buffered.receive_exactly(5)
    except (anyio.IncompleteRead, anyio.BrokenResourceError):
        print("incomplete read")
        return None
    # Record content type 0x16 (handshake) followed by version bytes 03 01.
    if header[0:3] != b"\x16\x03\x01":
        print("wrong tls version, or not handshake")
        return None
    record_length = int.from_bytes(header[3:5], byteorder="big", signed=False)
    try:
        record = await buffered.receive_exactly(record_length)
    except (anyio.IncompleteRead, anyio.BrokenResourceError):
        print("incomplete read 2")
        return None
    return header, record


def _find_sni_extension(record, extensions_start):
    """Walk the ClientHello extensions and return the server_name extension body.

    Extensions run from *extensions_start* to the end of *record*; each is a
    2-byte type, a 2-byte length and that many bytes of data.  Returns
    ``None`` (after printing why) on a repeated extension type, when the
    extensions do not exactly fill the record, or when no type 0x0000
    extension is present.
    """
    record_length = len(record)
    extension_types = set()
    sni = None
    extension_type_start = extensions_start
    while extension_type_start < record_length:
        extension_len_start = extension_type_start + 2
        extension_type = record[extension_type_start:extension_len_start]
        if extension_type in extension_types:
            print("duplicate extension types")
            return None
        extension_types.add(extension_type)
        extension_start = extension_len_start + 2
        extension_length = int.from_bytes(
            record[extension_len_start:extension_start],
            byteorder="big",
            signed=False,
        )
        extension_type_start = extension_start + extension_length
        if extension_type == b"\x00\x00":
            sni = record[extension_start:extension_type_start]
    # An extension whose declared length overruns the record lands past the end.
    if extension_type_start != record_length:
        print("too many extensions")
        return None
    if sni is None:
        print("no sni")
        return None
    return sni


def _parse_sni_host_name(sni):
    """Return the type-0 (host_name) entry from a server_name extension body.

    The body is a 2-byte list length followed by entries of a 1-byte name
    type, a 2-byte length and that many bytes.  Returns ``None`` (after
    printing why) if the lengths are inconsistent, a name type repeats, or
    there is no non-empty type-0 entry.
    """
    sni_extension_length = len(sni)
    sni_entry_type_start = 2
    sni_list_length = int.from_bytes(
        sni[0:sni_entry_type_start], byteorder="big", signed=False
    )
    if sni_entry_type_start + sni_list_length != sni_extension_length:
        print("invalid sni length")
        return None
    sni_entry_types = set()
    domain_name = None
    while sni_entry_type_start < sni_extension_length:
        sni_entry_type = sni[sni_entry_type_start]
        if sni_entry_type in sni_entry_types:
            print("duplicate sni")
            return None
        sni_entry_types.add(sni_entry_type)
        sni_entry_len_start = sni_entry_type_start + 1
        sni_entry_start = sni_entry_len_start + 2
        sni_entry_length = int.from_bytes(
            sni[sni_entry_len_start:sni_entry_start],
            byteorder="big",
            signed=False,
        )
        sni_entry_type_start = sni_entry_start + sni_entry_length
        if sni_entry_type == 0:
            domain_name = sni[sni_entry_start:sni_entry_type_start]
    if sni_entry_type_start != sni_extension_length:
        print("too much sni")
        return None
    # Without a usable name there is nowhere to connect to.
    if not domain_name:
        print("no host name in sni")
        return None
    return domain_name


def _parse_server_name(record):
    """Extract the SNI host name (as bytes) from a ClientHello handshake message.

    *record* is the payload of the client's first TLS record.  Only a
    ClientHello with client version 03 03, a single null compression method
    and extensions that end exactly at the record boundary is accepted.
    Returns ``None`` after printing a diagnostic for anything else.
    """
    record_length = len(record)
    # Handshake type 0x01 is ClientHello; an empty record has no type byte.
    if not record or record[0] != 0x01:
        print("not client hello")
        print(record)
        return None
    length = int.from_bytes(record[1:4], byteorder="big", signed=False)
    if record[4:6] != b"\x03\x03":
        print(f"bad client version, {record[4:6]}")
        return None
    # Fixed prefix read below: type(1) + length(3) + version(2) + random(32)
    # + session id length(1) = 39 bytes; shorter records would IndexError.
    if record_length < 39:
        print("client hello too short")
        return None
    client_random = record[6:38]
    print(len(client_random))
    session_id_length = record[38]
    print(session_id_length)
    session_id = record[39 : 39 + session_id_length]
    print(session_id)
    cipher_suites_len_start = 39 + session_id_length
    cipher_suites_length = int.from_bytes(
        record[cipher_suites_len_start : cipher_suites_len_start + 2],
        byteorder="big",
        signed=False,
    )
    print(cipher_suites_length)
    cipher_suites_start = cipher_suites_len_start + 2
    compression_methods_start = cipher_suites_start + cipher_suites_length
    cipher_suites = record[cipher_suites_start:compression_methods_start]
    print(cipher_suites)
    # Require exactly one compression method, and that it is 0x00.
    if (
        record[compression_methods_start : compression_methods_start + 2]
        != b"\x01\x00"
    ):
        print("compression enabled")
        return None
    extensions_len_start = compression_methods_start + 2
    if record_length <= extensions_len_start:
        print("no extensions")
        return None
    extensions_start = extensions_len_start + 2
    extensions_length = int.from_bytes(
        record[extensions_len_start:extensions_start],
        byteorder="big",
        signed=False,
    )
    print(extensions_start)
    if extensions_start + extensions_length != record_length:
        print("extensions are not the last field")
        return None
    sni = _find_sni_extension(record, extensions_start)
    if sni is None:
        return None
    domain_name = _parse_sni_host_name(sni)
    if domain_name is None:
        return None
    print(domain_name)
    print(length, record_length)
    return domain_name


async def handle(client):
    """Serve one client connection by routing it on its TLS SNI host name.

    The client's first TLS record is read (within 30 seconds) and parsed
    without decryption to recover the server name.  A TCP connection is then
    opened to that host on port 443.  The bytes already consumed are replayed
    to it, and data is copied in both directions with ``pipe``.  Any client
    that fails a check, times out, or names an unreachable host is closed
    after a printed diagnostic instead of raising.  NOTE(review): presumably
    an exception escaping this handler would reach ``listener.serve`` and
    stop the whole server — confirm against the anyio version in use.

    NOTE(review): the upstream host comes straight from the client, so any
    client can make this proxy connect to any resolvable host on port 443
    (including internal ones).  Consider an allow-list before exposing it.
    """
    async with client:
        buffered = anyio.streams.buffered.BufferedByteReceiveStream(client)
        hello = None
        # Bound only the wait for the ClientHello; the proxied session that
        # follows is not time-limited.
        with anyio.move_on_after(30) as scope:
            hello = await _read_client_hello(buffered)
        if hello is None:
            if scope.cancel_called:
                print("timed out waiting for client hello")
            return
        header, record = hello

        domain_name = _parse_server_name(record)
        if domain_name is None:
            return
        try:
            hostname = domain_name.decode("ascii")
        except UnicodeDecodeError:
            print(f"non-ascii server name, {domain_name}")
            return

        async def send():
            # Replay everything already read from the client (the first
            # record plus any bytes buffered after it), then stream the rest.
            try:
                await server.send(header + record + buffered.buffer)
            except (
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
            ):
                return
            await pipe(writer=server, reader=client)

        try:
            server = await anyio.connect_tcp(hostname, 443)
        except OSError as exc:
            print(f"could not connect to {hostname}, {exc}")
            return
        async with server, anyio.create_task_group() as tg:
            tg.start_soon(send)
            tg.start_soon(functools.partial(pipe, reader=server, writer=client))
async def main():
    """Listen for TCP connections on port 1234 and serve each one with ``handle``.

    The listener is used as an async context manager so its socket is closed
    if serving stops because of an error or cancellation.

    NOTE(review): no ``local_host`` is passed, so this presumably binds every
    interface; together with the SNI routing in ``handle`` that makes an open
    relay.  Confirm this is intended.
    """
    async with await anyio.create_tcp_listener(local_port=1234) as listener:
        await listener.serve(handle)
# Start the proxy only when run as a script, not when imported.
# The "trio" backend needs the trio package installed alongside anyio.
if __name__ == "__main__":
    anyio.run(main, backend="trio")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment