Skip to content

Instantly share code, notes, and snippets.

@aksh-at
Last active May 4, 2025 12:51
Show Gist options
  • Save aksh-at/e85a5517610a1a2bff35fac41d4c982f to your computer and use it in GitHub Desktop.
Save aksh-at/e85a5517610a1a2bff35fac41d4c982f to your computer and use it in GitHub Desktop.
Modal QUIC NAT hole-punching
"""
Proof-of-concept for NAT traversal and low-latency communication over QUIC
between two Modal containers.
In theory this could be used to establish a low-latency p2p connection between a
service running outside Modal and a Modal GPU container, e.g. for real-time
inference on a video stream. Please let us know if you try it!
Usage:
> modal run modal_quic_hole_punch.py
"""
from typing import Literal
import modal
import time
# Modal app that the functions and the local entrypoint in this file attach to.
app = modal.App("quic-hole-punch")

# Container image used by the functions below: Debian slim plus the Python
# packages that the function bodies import. pynat is installed in its own pip
# step, after the other packages.
image = modal.Image.debian_slim()
image = image.pip_install("fastapi", "aioquic", "aiohttp", "six")
image = image.pip_install("pynat")
@app.function(image=image, max_containers=1)
@modal.asgi_app()
def rendezvous():
    """Build the ASGI app through which two peers swap their public addresses.

    Each peer POSTs its id ("A" or "B") and public ``ip``/``port`` to
    ``/register``. The reply carries the other peer's stored address under
    the ``"peer"`` key, or null if the other peer has not registered yet.
    Addresses are kept in a dict in this process's memory only.
    """
    from typing import Dict, Optional, Tuple

    from fastapi import FastAPI
    from pydantic import BaseModel

    class RegisterRequest(BaseModel):
        # Only the two demo peers are accepted.
        peer_id: Literal["A", "B"]
        ip: str
        port: int

    web_app = FastAPI()
    # peer id -> (public ip, public port) as last reported by that peer.
    registry: Dict[str, Tuple[str, int]] = {}

    @web_app.post("/register")
    async def register(req: RegisterRequest):
        # Record (or refresh) the caller's own address first.
        registry[req.peer_id] = (req.ip, req.port)

        # Work out which peer the caller is waiting for.
        if req.peer_id == "B":
            counterpart = "A"
        else:
            counterpart = "B"

        counterpart_addr: Optional[Tuple[str, int]] = registry.get(counterpart)
        # Null until the second peer registers.
        return {"peer": counterpart_addr}

    return web_app
async def get_ext_addr(sock, stun_server=("stun.ekiga.net", 3478)):
    """Return the external ``(ip, port)`` reported by a STUN server for *sock*.

    Args:
        sock: Socket handed to pynat's ``get_stun_response``; the caller in
            this file passes a bound UDP socket.
        stun_server: ``(host, port)`` of the STUN server to query. Defaults to
            the server this demo has always used.

    Returns:
        The ``ext_ip`` and ``ext_port`` fields of the STUN response, as a tuple.

    Raises:
        RuntimeError: If the STUN query yields no usable external address.
            Failing here is deliberate: handing a missing address to the
            rendezvous server would leave the caller polling forever.
    """
    from pynat import get_stun_response

    # NOTE(review): get_stun_response is a plain synchronous call, so it
    # presumably blocks the event loop while waiting for the reply, and its
    # interaction with a non-blocking socket should be confirmed against pynat.
    response = get_stun_response(sock, stun_server)

    try:
        ext_ip, ext_port = response["ext_ip"], response["ext_port"]
    except (TypeError, KeyError) as err:
        # Covers a None result or a response without the address fields.
        raise RuntimeError(
            f"STUN lookup via {stun_server[0]}:{stun_server[1]} returned no response"
        ) from err

    if ext_ip is None or ext_port is None:
        raise RuntimeError(
            f"STUN lookup via {stun_server[0]}:{stun_server[1]} returned no external address"
        )
    return ext_ip, ext_port
def create_cert(key):
    """Create a self-signed certificate for the given key.

    Args:
        key: Private key object; its ``public_key()`` is embedded in the
            certificate and the key itself signs it (SHA-256).

    Returns:
        The signed certificate, valid from the moment of the call for one day,
        with common name ``modal-quic-demo`` as both subject and issuer.
    """
    import datetime

    from cryptography import x509
    from cryptography.hazmat.primitives import hashes
    from cryptography.x509.oid import NameOID

    # Self-signed, so the same name serves as subject and issuer.
    name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "modal-quic-demo")])

    # Take one timezone-aware reading so both validity bounds derive from the
    # same instant, and to avoid the deprecated naive datetime.utcnow().
    now = datetime.datetime.now(datetime.timezone.utc)

    return (
        x509.CertificateBuilder()
        .subject_name(name)
        .issuer_name(name)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=1))
        .sign(key, hashes.SHA256())
    )
# Number of ping/pong exchanges per run: the QUIC client sends this many pings
# and the echo handler answers at most this many before closing its stream.
N_PINGS = 5
@app.function(image=image, region="jp")  # Run in 🇯🇵.
async def punch_and_quic(my_id: str, rendezvous_url: str, local_port: int = 5555):
    """Punch a UDP hole to the other peer, then ping/pong over QUIC through it.

    Steps: bind a UDP socket on ``local_port``, learn its public mapping via
    ``get_ext_addr``, exchange addresses through the rendezvous server, trade
    ``b"punch"`` datagrams until one arrives, then reuse the same local port
    for QUIC. Peer ``"B"`` is the QUIC client and prints round-trip times; any
    other id (``"A"`` in this demo) runs the QUIC echo server.

    Args:
        my_id: This peer's id as registered with the rendezvous server.
        rendezvous_url: Base URL of the rendezvous app (``/register`` is appended).
        local_port: Local UDP port used for STUN, punching and QUIC.

    Raises:
        RuntimeError: If no datagram arrives from the peer during punching, or
            if the client receives an unexpected payload on the QUIC stream.
    """
    import asyncio
    import socket
    import ssl

    import aiohttp
    from aioquic.asyncio import connect, serve
    from aioquic.quic.configuration import QuicConfiguration
    from cryptography.hazmat.primitives.asymmetric import ec

    # How long the server side waits for the echo stream to finish before
    # giving up, and how long it lingers afterwards.
    server_wait_s = 30.0
    server_linger_s = 1.0

    loop = asyncio.get_running_loop()

    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.bind(("0.0.0.0", local_port))
        sock.setblocking(False)

        # 1. Discover public mapping via STUN.
        pub_ip, pub_port = await get_ext_addr(sock)
        print(f"[{my_id}] Pub IP: {pub_ip}, Pub Port: {pub_port}")

        # 2. Register & wait for the peer's tuple.
        async with aiohttp.ClientSession() as s:
            while True:
                async with s.post(
                    f"{rendezvous_url}/register",
                    json={"peer_id": my_id, "ip": pub_ip, "port": pub_port},
                ) as resp:
                    # Surface HTTP errors instead of polling forever on them.
                    resp.raise_for_status()
                    peer = (await resp.json()).get("peer")
                if peer:
                    peer_ip, peer_port = peer
                    break
                await asyncio.sleep(1)

        # 3. Punch: keep sending until a datagram from the peer arrives.
        # NOTE(review): sending stops as soon as one datagram is received; if
        # the peer's NAT filtered our earlier datagrams the peer may still time
        # out. A short trailing burst might help; confirm on restrictive NATs.
        print(f"[{my_id}] Punching {pub_ip}:{pub_port} -> {peer_ip}:{peer_port}")
        for _ in range(50):  # 5s total.
            sock.sendto(b"punch", (peer_ip, peer_port))
            try:
                await asyncio.wait_for(loop.sock_recv(sock, 16), 0.1)
                break
            except asyncio.TimeoutError:
                continue
        else:
            raise RuntimeError("Hole punching failed – no response from peer")
        print(f"[{my_id}] Punched {pub_ip}:{pub_port} -> {peer_ip}:{peer_port}")
    finally:
        # Close socket (also on failure, so the port is not leaked).
        # Mapping should stay alive.
        sock.close()

    # 4. QUIC over the punched port. The certificate is self-signed, so
    # verification is disabled on both sides.
    is_client = my_id == "B"
    cfg = QuicConfiguration(
        is_client=is_client, alpn_protocols=["hq-29"], verify_mode=ssl.CERT_NONE
    )

    if not is_client:
        cfg.private_key = ec.generate_private_key(ec.SECP256R1())
        cfg.certificate = create_cert(cfg.private_key)

        echo_finished = asyncio.Event()
        # Strong references keep echo tasks from being garbage-collected
        # while they are still running.
        echo_tasks = set()

        async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
            # Answer up to N_PINGS pings on this stream, then close it.
            for _ in range(N_PINGS):
                data = await reader.read(100)
                if not data:
                    break
                if data != b"ping":
                    raise RuntimeError(f"Unexpected payload from client: {data!r}")
                writer.write(b"pong")
                await writer.drain()
            writer.close()

        def on_echo_done(task):
            echo_tasks.discard(task)
            echo_finished.set()
            # Report failures here; nobody awaits the task directly.
            if not task.cancelled() and task.exception() is not None:
                print(f"[{my_id}] Echo stream failed: {task.exception()!r}")

        def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
            task = asyncio.create_task(echo(reader, writer))
            echo_tasks.add(task)
            task.add_done_callback(on_echo_done)

        server = await serve(
            host="0.0.0.0",
            port=local_port,  # Use the punched port.
            configuration=cfg,
            stream_handler=handler,
        )
        try:
            # Stay up until the exchange is over rather than for a fixed
            # second, so a slightly late client still finds the server.
            try:
                await asyncio.wait_for(echo_finished.wait(), timeout=server_wait_s)
            except asyncio.TimeoutError:
                print(f"[{my_id}] No QUIC stream completed within {server_wait_s:.0f}s")
            else:
                # Linger briefly so the last packets can still go out.
                await asyncio.sleep(server_linger_s)
        finally:
            server.close()
    else:
        total_latency = 0.0
        async with connect(
            peer_ip,
            peer_port,
            configuration=cfg,
            local_port=local_port,
        ) as quic:
            reader, writer = await quic.create_stream()
            for i in range(N_PINGS):
                start_time = time.monotonic()
                writer.write(b"ping")
                await writer.drain()
                print(f"[{my_id}] Sent ping {i + 1}")
                response = await reader.read(100)
                if response != b"pong":
                    raise RuntimeError(f"Unexpected payload from server: {response!r}")
                rtt = time.monotonic() - start_time
                total_latency += rtt
                print(f"[{my_id}] Received pong {i + 1}")
                print(f"[{my_id}] Round-trip time: {rtt * 1000:.2f}ms")
                await asyncio.sleep(0.1)
            writer.close()
        print(f"[{my_id}] Average rtt: {(total_latency / N_PINGS) * 1000:.2f}ms")
@app.local_entrypoint()
def main():
    """Spawn peers "A" and "B" against the rendezvous app and wait for both."""
    # Spawn order matters only for readability: "A" first, then "B".
    calls = [
        punch_and_quic.spawn(my_id=peer_id, rendezvous_url=rendezvous.web_url)
        for peer_id in ("A", "B")
    ]
    modal.FunctionCall.gather(*calls)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment