Created
May 4, 2023 10:42
-
-
Save lithdew/4a009884829ad4c0ce41e3c3e8bf039b to your computer and use it in GitHub Desktop.
Rust (quinn, rustls): QUIC hole punching with a basic STUN client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package] | |
name = "quic-holepunching" | |
version = "0.1.0" | |
edition = "2021" | |
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | |
[dependencies] | |
anyhow = "1.0.71" | |
bitflags = "2.2.1" | |
num_enum = "0.6.1" | |
quinn = { git = "https://github.com/quinn-rs/quinn.git", rev = "99e462f8c868b1ac511a360a496e90a73a878f07", features = ["runtime-tokio"] } | |
rcgen = "0.10.0" | |
rustls = { version = "0.21.0", features = ["quic", "dangerous_configuration"] } | |
serde = { version = "1.0.160", features = ["derive"] } | |
serde_json = "1.0.96" | |
tokio = { version = "1.28.0", features = ["full"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use anyhow::Result; | |
mod stun { | |
use core::fmt::Debug; | |
pub const MAGIC_COOKIE: u32 = 0x2112a442u32; | |
pub const MAGIC_COOKIE_BYTES: [u8; 4] = MAGIC_COOKIE.to_be_bytes(); | |
#[derive( | |
Debug, | |
Copy, | |
Clone, | |
Eq, | |
PartialEq, | |
num_enum::FromPrimitive, | |
num_enum::IntoPrimitive, | |
serde::Serialize, | |
serde::Deserialize, | |
)] | |
#[repr(u16)] | |
#[serde(rename_all = "snake_case")] | |
pub enum Class { | |
Request = 0x0000, | |
Indication = 0x0010, | |
Success = 0x0100, | |
Error = 0x0110, | |
#[num_enum(catch_all)] | |
Unknown(u16), | |
} | |
#[derive( | |
Debug, | |
Copy, | |
Clone, | |
Eq, | |
PartialEq, | |
num_enum::FromPrimitive, | |
num_enum::IntoPrimitive, | |
serde::Serialize, | |
serde::Deserialize, | |
)] | |
#[repr(u8)] | |
#[serde(rename_all = "snake_case")] | |
pub enum Method { | |
Binding = 0x0001, | |
#[num_enum(catch_all)] | |
Unknown(u8), | |
} | |
#[derive( | |
Debug, | |
Copy, | |
Clone, | |
Eq, | |
PartialEq, | |
num_enum::FromPrimitive, | |
num_enum::IntoPrimitive, | |
serde::Serialize, | |
serde::Deserialize, | |
)] | |
#[repr(u16)] | |
#[serde(rename_all = "snake_case")] | |
pub enum Type { | |
MappedAddress = 0x0001, | |
Username = 0x0006, | |
MessageIntegrity = 0x0008, | |
ErrorCode = 0x0009, | |
UnknownAttributes = 0x000a, | |
Realm = 0x0014, | |
Nonce = 0x0015, | |
XorMappedAddress = 0x0020, | |
Software = 0x8022, | |
AlternateServer = 0x8023, | |
Fingerprint = 0x8028, | |
#[num_enum(catch_all)] | |
Unknown(u16), | |
} | |
#[derive( | |
Debug, | |
Copy, | |
Clone, | |
Eq, | |
PartialEq, | |
num_enum::FromPrimitive, | |
num_enum::IntoPrimitive, | |
serde::Serialize, | |
serde::Deserialize, | |
)] | |
#[repr(u8)] | |
#[serde(rename_all = "lowercase")] | |
pub enum AddressFamily { | |
IPv4 = 0x01, | |
IPv6 = 0x02, | |
#[num_enum(catch_all)] | |
Unknown(u8), | |
} | |
#[derive(Copy, Clone, serde::Deserialize)] | |
#[repr(transparent)] | |
#[serde(rename_all = "camelCase")] | |
pub struct HeaderFlags(pub u16); | |
impl HeaderFlags { | |
pub fn class(self) -> Class { | |
Class::from(self.0 & 0x0110) | |
} | |
pub fn method(self) -> Method { | |
Method::from((self.0 & 0x0001) as u8) | |
} | |
} | |
impl TryFrom<u16> for HeaderFlags { | |
type Error = anyhow::Error; | |
fn try_from(value: u16) -> std::result::Result<Self, Self::Error> { | |
if value & 0xc000 != 0x0000 { | |
return Err(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"Invalid header flags", | |
) | |
.into()); | |
} | |
Ok(HeaderFlags(value)) | |
} | |
} | |
impl Debug for HeaderFlags { | |
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { | |
f.debug_struct("HeaderFlags") | |
.field("class", &self.class()) | |
.field("method", &self.method()) | |
.finish() | |
} | |
} | |
impl serde::Serialize for HeaderFlags { | |
fn serialize<S: serde::Serializer>( | |
&self, | |
serializer: S, | |
) -> core::result::Result<S::Ok, S::Error> { | |
use serde::ser::SerializeStruct; | |
let mut flags = serializer.serialize_struct("HeaderFlags", 2)?; | |
flags.serialize_field("class", &self.class())?; | |
flags.serialize_field("method", &self.method())?; | |
flags.end() | |
} | |
} | |
#[repr(C)] | |
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)] | |
#[serde(rename_all = "camelCase")] | |
pub struct Header { | |
pub flags: HeaderFlags, | |
pub length: u16, | |
pub magic_cookie: u32, | |
pub transaction_id: [u32; 3], | |
} | |
impl Header { | |
pub const LENGTH: usize = 20; | |
} | |
const _: [(); std::mem::size_of::<Header>()] = [(); Header::LENGTH]; | |
impl TryFrom<&[u8]> for Header { | |
type Error = anyhow::Error; | |
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> { | |
if value.len() < Header::LENGTH { | |
return Err(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"Invalid header length", | |
) | |
.into()); | |
} | |
let mut header = unsafe { std::ptr::read_unaligned(value.as_ptr() as *const Header) }; | |
header.flags = HeaderFlags(u16::from_be(header.flags.0)); | |
header.length = u16::from_be(header.length); | |
Ok(header) | |
} | |
} | |
impl From<Header> for [u8; Header::LENGTH] { | |
fn from(val: Header) -> [u8; Header::LENGTH] { | |
let mut header = val; | |
header.flags = HeaderFlags(header.flags.0.to_be()); | |
header.length = header.length.to_be(); | |
let mut buffer = [0u8; Header::LENGTH]; | |
unsafe { | |
std::ptr::copy_nonoverlapping( | |
&header as *const Header as *const u8, | |
buffer.as_mut_ptr(), | |
Header::LENGTH, | |
); | |
} | |
buffer | |
} | |
} | |
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] | |
#[serde(rename_all = "lowercase")] | |
pub enum MappedAddress { | |
IPv4(std::net::SocketAddrV4), | |
IPv6(std::net::SocketAddrV6), | |
Unknown { | |
family: AddressFamily, | |
port: u16, | |
address: Vec<u8>, | |
}, | |
} | |
impl MappedAddress { | |
pub const MIN_LENGTH: usize = 8; | |
pub fn xor(&self) -> Self { | |
let mut address = self.clone(); | |
match &mut address { | |
Self::IPv4(address) => { | |
address.set_port(address.port() ^ (MAGIC_COOKIE >> 16) as u16); | |
let mut octets = address.ip().octets(); | |
for i in 0..octets.len() { | |
octets[i] ^= MAGIC_COOKIE_BYTES[i % 4]; | |
} | |
address.set_ip(std::net::Ipv4Addr::from(octets)); | |
} | |
Self::IPv6(address) => { | |
address.set_port(address.port() ^ (MAGIC_COOKIE >> 16) as u16); | |
let mut octets = address.ip().octets(); | |
for i in 0..octets.len() { | |
octets[i] ^= MAGIC_COOKIE_BYTES[i % 4]; | |
} | |
address.set_ip(std::net::Ipv6Addr::from(octets)); | |
} | |
Self::Unknown { port, address, .. } => { | |
*port ^= (MAGIC_COOKIE >> 16) as u16; | |
for i in 0..address.len() { | |
address[i] ^= MAGIC_COOKIE_BYTES[i % 4]; | |
} | |
} | |
} | |
address | |
} | |
} | |
impl TryFrom<MappedAddress> for std::net::SocketAddr { | |
type Error = std::io::Error; | |
fn try_from(value: MappedAddress) -> std::result::Result<Self, Self::Error> { | |
match value { | |
MappedAddress::IPv4(address) => Ok(std::net::SocketAddr::V4(address)), | |
MappedAddress::IPv6(address) => Ok(std::net::SocketAddr::V6(address)), | |
MappedAddress::Unknown { .. } => Err(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"Unknown address family", | |
)), | |
} | |
} | |
} | |
impl TryFrom<&[u8]> for MappedAddress { | |
type Error = anyhow::Error; | |
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> { | |
if value.len() < MappedAddress::MIN_LENGTH { | |
return Err( | |
std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid length").into(), | |
); | |
} | |
let family = AddressFamily::from(value[1]); | |
let port = u16::from_be_bytes([value[2], value[3]]); | |
Ok(match family { | |
AddressFamily::IPv4 => { | |
let address = [value[4], value[5], value[6], value[7]]; | |
Self::IPv4(std::net::SocketAddrV4::new(address.into(), port)) | |
} | |
AddressFamily::IPv6 => { | |
if value.len() - MappedAddress::MIN_LENGTH != 16 { | |
return Err(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"Invalid length", | |
) | |
.into()); | |
} | |
let address = [ | |
value[4], value[5], value[6], value[7], value[8], value[9], value[10], | |
value[11], value[12], value[13], value[14], value[15], value[16], | |
value[17], value[18], value[19], | |
]; | |
Self::IPv6(std::net::SocketAddrV6::new(address.into(), port, 0, 0)) | |
} | |
AddressFamily::Unknown(_) => Self::Unknown { | |
family, | |
port, | |
address: value[4..].to_vec(), | |
}, | |
}) | |
} | |
} | |
#[derive(Debug, serde::Serialize, serde::Deserialize)] | |
#[serde(rename_all = "camelCase")] | |
pub enum AttributeValue { | |
MappedAddress(MappedAddress), | |
XorMappedAddress(MappedAddress), | |
Software { value: String }, | |
AlternateServer(MappedAddress), | |
Fingerprint { value: u32 }, | |
Unknown(Vec<u8>), | |
} | |
#[derive(Debug, serde::Serialize, serde::Deserialize)] | |
#[serde(rename_all = "camelCase")] | |
pub struct Attribute { | |
pub length: u16, | |
pub value: AttributeValue, | |
} | |
impl Attribute { | |
pub const HEADER_LENGTH: usize = 4; | |
} | |
impl TryFrom<&[u8]> for Attribute { | |
type Error = anyhow::Error; | |
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> { | |
if value.len() < Attribute::HEADER_LENGTH { | |
return Err(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"Invalid attribute length", | |
) | |
.into()); | |
} | |
let r#type = Type::from(u16::from_be_bytes([value[0], value[1]])); | |
let length = u16::from_be_bytes([value[2], value[3]]); | |
let value = &value[Attribute::HEADER_LENGTH..][..length as usize]; | |
Ok(Attribute { | |
length, | |
value: match r#type { | |
Type::MappedAddress => { | |
AttributeValue::MappedAddress(MappedAddress::try_from(value)?) | |
} | |
Type::XorMappedAddress => { | |
AttributeValue::XorMappedAddress(MappedAddress::try_from(value)?.xor()) | |
} | |
Type::Software => { | |
let value = std::str::from_utf8(value)?; | |
AttributeValue::Software { | |
value: value.to_owned(), | |
} | |
} | |
Type::AlternateServer => { | |
AttributeValue::AlternateServer(MappedAddress::try_from(value)?) | |
} | |
Type::Fingerprint => { | |
if value.len() != 4 { | |
return Err(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"Invalid fingerprint length", | |
) | |
.into()); | |
} | |
let value = u32::from_be_bytes([value[0], value[1], value[2], value[3]]); | |
AttributeValue::Fingerprint { value } | |
} | |
_ => AttributeValue::Unknown(value.to_owned()), | |
}, | |
}) | |
} | |
} | |
#[derive(Debug, serde::Serialize, serde::Deserialize)] | |
#[serde(rename_all = "camelCase")] | |
pub struct Message { | |
pub header: Header, | |
pub attributes: Vec<Attribute>, | |
} | |
impl Message { | |
pub const MAX_LENGTH: usize = 548; | |
} | |
impl TryFrom<&[u8]> for Message { | |
type Error = anyhow::Error; | |
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> { | |
let header = Header::try_from(value)?; | |
let mut attributes = Vec::new(); | |
let mut offset = Header::LENGTH; | |
while offset < value.len() { | |
let attribute = Attribute::try_from(&value[offset..])?; | |
offset += Attribute::HEADER_LENGTH + attribute.length as usize; | |
attributes.push(attribute); | |
} | |
Ok(Self { header, attributes }) | |
} | |
} | |
} | |
/// TLS certificate verifier that accepts any server certificate.
///
/// Both peers present throwaway self-signed certificates, so there is no CA to
/// validate against. As a result, the remote peer is NOT authenticated.
struct SkipServerVerification;

impl SkipServerVerification {
    /// Returns a shared handle, ready to hand to rustls as a custom verifier.
    fn new() -> std::sync::Arc<Self> {
        let verifier = SkipServerVerification;
        std::sync::Arc::new(verifier)
    }
}
impl rustls::client::ServerCertVerifier for SkipServerVerification { | |
fn verify_server_cert( | |
&self, | |
_end_entity: &rustls::Certificate, | |
_intermediates: &[rustls::Certificate], | |
_server_name: &rustls::ServerName, | |
_scts: &mut dyn Iterator<Item = &[u8]>, | |
_ocsp_response: &[u8], | |
_now: std::time::SystemTime, | |
) -> Result<rustls::client::ServerCertVerified, rustls::Error> { | |
Ok(rustls::client::ServerCertVerified::assertion()) | |
} | |
} | |
fn configure_client() -> quinn::ClientConfig { | |
let crypto = rustls::ClientConfig::builder() | |
.with_safe_defaults() | |
.with_custom_certificate_verifier(SkipServerVerification::new()) | |
.with_no_client_auth(); | |
quinn::ClientConfig::new(std::sync::Arc::new(crypto)) | |
} | |
fn configure_server() -> Result<quinn::ServerConfig> { | |
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])?; | |
let cert_der = cert.serialize_der()?; | |
let priv_key = cert.serialize_private_key_der(); | |
let priv_key = rustls::PrivateKey(priv_key); | |
let cert_chain = vec![rustls::Certificate(cert_der)]; | |
let mut server_config = quinn::ServerConfig::with_single_cert(cert_chain, priv_key)?; | |
let transport_config = std::sync::Arc::get_mut(&mut server_config.transport).unwrap(); | |
transport_config.max_concurrent_uni_streams(0_u8.into()); | |
Ok(server_config) | |
} | |
#[tokio::main] | |
async fn main() -> Result<()> { | |
let runtime = quinn::default_runtime() | |
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found"))?; | |
let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?; | |
println!("Bound UDP socket to: {}", socket.local_addr()?); | |
let stun_request: [u8; stun::Header::LENGTH] = (stun::Header { | |
flags: stun::HeaderFlags(Into::<u8>::into(stun::Method::Binding) as u16), | |
length: 0, | |
magic_cookie: stun::MAGIC_COOKIE, | |
transaction_id: [0u32; 3], | |
}) | |
.into(); | |
socket | |
.send_to(&stun_request, "stun.l.google.com:19302") | |
.await?; | |
let message = { | |
let mut buffer = [0u8; stun::Message::MAX_LENGTH]; | |
let (num_read, _) = socket.recv_from(&mut buffer).await?; | |
stun::Message::try_from(&buffer[..num_read])? | |
}; | |
println!("{}", serde_json::to_string_pretty(&message)?); | |
let external_address = message | |
.attributes | |
.iter() | |
.find_map(|attribute| match &attribute.value { | |
stun::AttributeValue::MappedAddress(address) => Some(address), | |
stun::AttributeValue::XorMappedAddress(address) => Some(address), | |
_ => None, | |
}) | |
.unwrap(); | |
println!( | |
"Send this over: {}", | |
serde_json::to_string(external_address)? | |
); | |
use std::io::Write; | |
std::io::stdout().write_all(b"Provide the JSON from the other side: ")?; | |
std::io::stdout().flush()?; | |
let mut peer_address_json = String::new(); | |
std::io::stdin().read_line(&mut peer_address_json)?; | |
let peer_address: stun::MappedAddress = serde_json::from_str(&peer_address_json)?; | |
let mut endpoint = quinn::Endpoint::new( | |
quinn::EndpointConfig::default(), | |
Some(configure_server()?), | |
socket.into_std()?, | |
runtime, | |
)?; | |
endpoint.set_default_client_config(configure_client()); | |
let client_endpoint = endpoint.clone(); | |
let server_endpoint = endpoint.clone(); | |
tokio::spawn(async move { | |
while let Ok(peer) = | |
client_endpoint.connect(peer_address.to_owned().try_into().unwrap(), "localhost") | |
{ | |
let peer = match peer.await { | |
Ok(peer) => peer, | |
Err(e) => { | |
println!("Error establishing connection on outgoing socket: {}", e); | |
continue; | |
} | |
}; | |
let (mut send, mut recv) = match peer.open_bi().await { | |
Err(quinn::ConnectionError::ApplicationClosed { .. }) => return, | |
Err(err) => { | |
println!( | |
"Error opening bidirectional stream on outgoing socket: {}", | |
err | |
); | |
return; | |
} | |
Ok(stream) => stream, | |
}; | |
use tokio::io::AsyncWriteExt; | |
tokio::spawn(async move { | |
let mut buffer = [0u8; 64 * 1024]; | |
loop { | |
if let Err(e) = recv.read_exact(&mut buffer[.."PONG".len()]).await { | |
println!("Error reading from outgoing socket: {}", e); | |
return; | |
} | |
println!( | |
"Got '{}' from: {:?}", | |
String::from_utf8_lossy(&buffer[.."PONG".len()]), | |
peer.remote_address() | |
); | |
} | |
}); | |
loop { | |
if let Err(e) = send.write_all("PING".as_bytes()).await { | |
println!("Error writing to outgoing socket: {}", e); | |
return; | |
} | |
if let Err(e) = send.flush().await { | |
println!("Error flushing writes to outgoing socket: {}", e); | |
return; | |
} | |
tokio::time::sleep(std::time::Duration::from_secs(1)).await; | |
} | |
} | |
}); | |
while let Some(peer) = server_endpoint.accept().await { | |
let peer = match peer.await { | |
Ok(peer) => peer, | |
Err(e) => { | |
println!("Error accepting connection on incoming socket: {}", e); | |
continue; | |
} | |
}; | |
println!("Peer {} connected.", peer.remote_address()); | |
use tokio::io::AsyncWriteExt; | |
tokio::spawn(async move { | |
let (mut send, mut recv) = match peer.accept_bi().await { | |
Err(quinn::ConnectionError::ApplicationClosed { .. }) => return, | |
Err(err) => { | |
println!( | |
"Error opening bidirectional stream on incoming socket: {}", | |
err | |
); | |
return; | |
} | |
Ok(stream) => stream, | |
}; | |
let mut buffer = [0u8; 64 * 1024]; | |
loop { | |
if let Err(e) = recv.read_exact(&mut buffer[.."PING".len()]).await { | |
println!("Error reading from incoming socket: {}", e); | |
return; | |
}; | |
println!( | |
"Got '{}' from: {:?}", | |
String::from_utf8_lossy(&buffer[.."PING".len()]), | |
peer.remote_address() | |
); | |
if let Err(e) = send.write_all("PONG".as_bytes()).await { | |
println!("Error writing to incoming socket: {}", e); | |
return; | |
} | |
if let Err(e) = send.flush().await { | |
println!("Error flushing writes to incoming socket: {}", e); | |
return; | |
} | |
} | |
}); | |
} | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment