Skip to content

Instantly share code, notes, and snippets.

@graingert
Last active December 18, 2024 11:02
Show Gist options
  • Save graingert/380daef873f817d2dab2c9825588ebfb to your computer and use it in GitHub Desktop.
Save graingert/380daef873f817d2dab2c9825588ebfb to your computer and use it in GitHub Desktop.
import anyio.streams.buffered
import anyio.abc
import socket
import functools
async def pipe(reader, writer):
    """Copy everything received on ``reader`` to ``writer``.

    The loop ends when ``reader`` reports end-of-stream or when either stream
    is closed or broken. Afterwards the read side of ``reader``'s raw socket is
    shut down and an EOF is sent on ``writer`` so the peer sees a half-close.

    Args:
        reader: stream to receive from; must expose
            ``anyio.abc.SocketAttribute.raw_socket`` in its extra attributes.
        writer: stream to forward the data to; must support ``send_eof()``.
    """
    raw_socket = reader.extra_attributes[anyio.abc.SocketAttribute.raw_socket]()
    try:
        while True:
            try:
                data = await reader.receive()
            except (
                anyio.EndOfStream,
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
            ):
                break
            try:
                await writer.send(data)
            except (
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
            ):
                break
    finally:
        try:
            raw_socket.shutdown(socket.SHUT_RD)
        except OSError:
            # Best effort: the socket may already be disconnected.
            pass
        finally:
            try:
                await writer.send_eof()
            except (
                OSError,
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
            ):
                # The loop usually ends *because* a peer went away, so the
                # half-close is best effort too; an error escaping from here
                # would propagate out of the caller's task group.
                pass
async def _receive_first_record(buffered):
    """Read the first TLS record from ``buffered``.

    Returns ``(header, record)`` where ``header`` is the 5-byte record header
    and ``record`` is the record payload, or ``None`` (after printing the
    reason) when the stream ends early, the connection breaks, or the header
    is not a handshake record with version bytes ``03 01``.
    """
    try:
        header = await buffered.receive_exactly(5)
    except (
        anyio.IncompleteRead,
        anyio.ClosedResourceError,
        anyio.BrokenResourceError,
    ):
        print("incomplete read")
        return None
    if header[0:3] != b"\x16\x03\x01":
        print("wrong tls version, or not handshake")
        return None
    record_length = int.from_bytes(header[3:5], byteorder="big", signed=False)
    try:
        record = await buffered.receive_exactly(record_length)
    except (
        anyio.IncompleteRead,
        anyio.ClosedResourceError,
        anyio.BrokenResourceError,
    ):
        print("incomplete read 2")
        return None
    return header, record


def _parse_sni_host_name(record):
    """Extract the SNI host name from a ClientHello handshake message.

    ``record`` is the payload of the first TLS record. The message is walked
    field by field; every rejection prints a short diagnostic and returns
    ``None``. The input is untrusted, so every single-byte index is preceded
    by a length check (slices are safe on short input, indexing is not).

    Returns:
        The raw host name bytes of the ``host_name`` (type 0) SNI entry, or
        ``None`` if the message is rejected or has no such entry.
    """
    record_length = len(record)
    if record_length == 0 or record[0] != 0x01:
        print("not client hello")
        print(record)
        return None
    length = int.from_bytes(record[1:4], byteorder="big", signed=False)
    if record[4:6] != b"\x03\x03":
        print(f"bad client version, {record[4:6]}")
        return None
    # type (1) + length (3) + version (2) + random (32) + session id length (1)
    if record_length < 39:
        print("client hello too short")
        return None
    client_random = record[6:38]
    print(len(client_random))
    session_id_length = record[38]
    print(session_id_length)
    session_id = record[39 : 39 + session_id_length]
    print(session_id)
    cipher_suites_len_start = 39 + session_id_length
    cipher_suites_length = int.from_bytes(
        record[cipher_suites_len_start : cipher_suites_len_start + 2],
        byteorder="big",
        signed=False,
    )
    print(cipher_suites_length)
    cipher_suites_start = cipher_suites_len_start + 2
    compression_methods_start = cipher_suites_start + cipher_suites_length
    cipher_suites = record[cipher_suites_start:compression_methods_start]
    print(cipher_suites)
    extensions_len_start = compression_methods_start + 2
    # Only "one compression method, null" (01 00) is accepted.
    if (
        record[compression_methods_start : compression_methods_start + 2]
        != b"\x01\x00"
    ):
        print("compression enabled")
        return None
    if record_length <= extensions_len_start:
        print("no extensions")
        return None
    extensions_start = extensions_len_start + 2
    extensions_length = int.from_bytes(
        record[extensions_len_start:extensions_start],
        byteorder="big",
        signed=False,
    )
    print(extensions_start)
    if extensions_start + extensions_length != record_length:
        print("extensions are not the last field")
        return None

    # Walk the extensions: 2-byte type, 2-byte length, then the body.
    extension_types = set()
    extension_type_start = extensions_start
    sni = None
    while extension_type_start < record_length:
        extension_len_start = extension_type_start + 2
        extension_type = record[extension_type_start:extension_len_start]
        if extension_type in extension_types:
            print("duplicate extension types")
            return None
        extension_types.add(extension_type)
        extension_start = extension_len_start + 2
        extension_length = int.from_bytes(
            record[extension_len_start:extension_start],
            byteorder="big",
            signed=False,
        )
        extension_type_start = extension_start + extension_length
        if extension_type == b"\x00\x00":
            sni = record[extension_start:extension_type_start]
    if extension_type_start != record_length:
        print("too many extensions")
        return None
    if sni is None:
        print("no sni")
        return None

    # Walk the server name list: 2-byte list length, then entries made of a
    # 1-byte type, a 2-byte length and the name itself.
    sni_extension_length = len(sni)
    sni_entry_types = set()
    sni_entry_type_start = 2
    sni_list_length = int.from_bytes(
        sni[0:sni_entry_type_start], byteorder="big", signed=False
    )
    if sni_entry_type_start + sni_list_length != sni_extension_length:
        print("invalid sni length")
        return None
    domain_name = None
    while sni_entry_type_start < sni_extension_length:
        sni_entry_type = sni[sni_entry_type_start]
        if sni_entry_type in sni_entry_types:
            print("duplicate sni")
            return None
        sni_entry_types.add(sni_entry_type)
        sni_entry_len_start = sni_entry_type_start + 1
        sni_entry_start = sni_entry_len_start + 2
        sni_entry_length = int.from_bytes(
            sni[sni_entry_len_start:sni_entry_start],
            byteorder="big",
            signed=False,
        )
        sni_entry_type_start = sni_entry_start + sni_entry_length
        if sni_entry_type == 0:
            domain_name = sni[sni_entry_start:sni_entry_type_start]
    if sni_entry_type_start != sni_extension_length:
        print("too much sni")
        return None
    if not domain_name:
        # An SNI extension without a (non-empty) host_name entry gives us
        # nothing to connect to.
        print("no host name in sni")
        return None
    print(domain_name)
    print(length, record_length)
    return domain_name


async def handle(client):
    """Proxy one client connection to the host named in its TLS SNI.

    Reads the first TLS record from ``client`` (allowing 30 seconds for it),
    extracts the SNI host name from the ClientHello, connects to that host on
    port 443, replays the bytes already consumed and then pipes data in both
    directions until the streams finish.

    Any problem with the handshake or the upstream connection is printed and
    ends this connection only; it is not raised to the caller.

    Args:
        client: the accepted client stream; it is closed when this returns.
    """
    async with client:
        buffered = anyio.streams.buffered.BufferedByteReceiveStream(client)
        hello = None
        # The time limit covers only waiting for the ClientHello, not the
        # lifetime of the proxied connection.
        with anyio.move_on_after(30) as scope:
            hello = await _receive_first_record(buffered)
        if scope.cancelled_caught:
            print("timed out waiting for client hello")
            return
        if hello is None:
            return
        header, record = hello

        domain_name = _parse_sni_host_name(record)
        if domain_name is None:
            return
        try:
            host = domain_name.decode("ascii")
        except UnicodeDecodeError:
            print("sni host name is not ascii")
            return

        # NOTE(review): ``host`` is chosen by the client, so this will connect
        # to any name or address it supplies - consider an allow-list.
        try:
            server = await anyio.connect_tcp(host, 443)
        except OSError as exc:
            print(f"could not connect to {host}: {exc}")
            return

        async def send():
            # Replay what was consumed while sniffing (record header, record
            # payload and anything still buffered) before piping the rest.
            try:
                await server.send(header + record + buffered.buffer)
            except (
                anyio.ClosedResourceError,
                anyio.BrokenResourceError,
            ):
                return
            await pipe(writer=server, reader=client)

        async with server, anyio.create_task_group() as tg:
            tg.start_soon(send)
            tg.start_soon(functools.partial(pipe, reader=server, writer=client))
async def main():
    """Listen on TCP port 1234 and serve each connection with ``handle``."""
    tcp_listener = await anyio.create_tcp_listener(local_port=1234)
    await tcp_listener.serve(handle)
# Script entry point: run ``main`` to completion on anyio's trio backend.
anyio.run(main, backend="trio")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment