Skip to content

Instantly share code, notes, and snippets.

@conectado
Created August 15, 2024 03:18
Show Gist options
  • Save conectado/3a920134abacf24b1fc89bb08c08812e to your computer and use it in GitHub Desktop.
Save conectado/3a920134abacf24b1fc89bb08c08812e to your computer and use it in GitHub Desktop.
stun-tests.rs
use anyhow::Context;
use bytecodec::{DecodeExt, EncodeExt as _};
use futures::{Future, FutureExt};
use std::{
collections::VecDeque,
net::{SocketAddr, ToSocketAddrs},
pin::Pin,
task::{ready, Poll, Waker},
time::{Duration, Instant},
};
use stun_codec::rfc5389::{attributes::XorMappedAddress, methods::BINDING, Attribute};
use stun_codec::*;
use tokio::net::UdpSocket;
/// Builds a serialized STUN Binding request with a fresh random transaction ID.
///
/// The message is constructed locally and carries no attributes, so encoding it can
/// only fail because of a bug in the codec, not because of bad input.
///
/// # Panics
///
/// Panics if `stun_codec` fails to encode the freshly constructed request.
pub fn make_binding_request() -> Vec<u8> {
    let request = Message::<Attribute>::new(
        MessageClass::Request,
        BINDING,
        TransactionId::new(rand::random()),
    );
    MessageEncoder::<Attribute>::default()
        .encode_into_bytes(request)
        .expect("encoding a locally-built, attribute-free STUN binding request must succeed")
}
pub fn parse_binding_response(buf: &[u8]) -> SocketAddr {
let message = MessageDecoder::<Attribute>::default()
.decode_from_bytes(buf)
.unwrap()
.unwrap();
message
.get_attribute::<XorMappedAddress>()
.unwrap()
.address()
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
let mut servers = Vec::new();
let server = "stun.cloudflare.com:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.100.204.176:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.118.92.215:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "35.199.88.84:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.80.109.129:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "35.189.51.106:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.89.78.71:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "35.231.210.194:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.102.65.13:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.135.13.105:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.94.219.196:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.38.30.124:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "35.243.253.170:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let server = "34.88.39.18:3478"
.to_socket_addrs()?
.next()
.context("Failed to resolve hostname")?;
servers.push(server);
let mut binding = StunBinding::new(servers);
let mut timer = Timer::default();
loop {
for _ in 0..100 {
socket
.send_to(
&[0xab],
"10.20.30.40:1234"
.to_socket_addrs()
.unwrap()
.next()
.unwrap(),
)
.await?;
}
if let Some(transmit) = binding.poll_transmit() {
socket.send_to(&transmit.payload, transmit.dst).await?;
continue;
}
let mut buf = vec![0u8; 100];
tokio::select! {
Some(time) = &mut timer => {
binding.handle_timeout(time);
},
res = socket.recv(&mut buf) => {
let num_read = res?;
binding.handle_input(&buf[..num_read], Instant::now());
}
}
timer.reset_to(binding.poll_timeout());
if let Some(address) = binding.public_address() {
println!("Our public IP is: {address}");
}
}
}
/// An optional, resettable deadline that can be awaited.
///
/// - **Armed:** polling resolves to `Some(deadline)` once the deadline has passed.
/// - **Disarmed:** polling resolves immediately to `None`. The polling task's waker is
///   recorded so that a later [`Timer::reset_to`] that arms the timer wakes it.
#[derive(Default)]
struct Timer {
    inner: Option<Pin<Box<tokio::time::Sleep>>>,
    waker: Option<Waker>,
}

impl Timer {
    /// Arms the timer to fire at `next`, or disarms it when `next` is `None`.
    ///
    /// An already-armed timer is moved to the new deadline in place. Arming a disarmed
    /// timer wakes the task that last polled it, if any.
    fn reset_to(&mut self, next: Option<Instant>) {
        let Some(deadline) = next else {
            self.inner = None;
            return;
        };

        if let Some(sleep) = self.inner.as_mut() {
            sleep.as_mut().reset(deadline.into());
            return;
        }

        self.inner = Some(Box::pin(tokio::time::sleep_until(deadline.into())));
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}

impl Future for Timer {
    type Output = Option<Instant>;

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // `Timer` only holds `Unpin` fields, so it is safe to work on `&mut Self`.
        let this = self.get_mut();
        match this.inner.as_mut() {
            None => {
                this.waker = Some(cx.waker().clone());
                Poll::Ready(None)
            }
            Some(sleep) => {
                ready!(sleep.as_mut().poll(cx));
                Poll::Ready(Some(sleep.deadline().into()))
            }
        }
    }
}
struct StunBinding {
servers: Vec<SocketAddr>,
state: State,
buffered_transmits: VecDeque<Transmit>,
}
impl StunBinding {
fn new(servers: Vec<SocketAddr>) -> Self {
let mut buffered_transmits = VecDeque::new();
for server in &servers {
buffered_transmits.push_back(Transmit {
dst: *server,
payload: make_binding_request(),
});
}
Self {
servers,
state: State::Sent,
buffered_transmits,
}
}
fn handle_input(&mut self, packet: &[u8], now: Instant) {
let address = parse_binding_response(packet);
self.state = State::Received { address, at: now };
}
fn poll_transmit(&mut self) -> Option<Transmit> {
self.buffered_transmits.pop_front()
}
/// Notifies `StunBinding` that time has advanced to `now`.
fn handle_timeout(&mut self, now: Instant) {
let last_received_at = match self.state {
State::Sent => return,
State::Received { at, .. } => at,
};
if now.duration_since(last_received_at) < Duration::from_secs(5) {
return;
}
for server in &self.servers {
self.buffered_transmits.push_front(Transmit {
dst: *server,
payload: make_binding_request(),
});
}
self.state = State::Sent;
}
/// Returns the timestamp when we next expect `handle_timeout` to be called.
fn poll_timeout(&self) -> Option<Instant> {
match self.state {
State::Sent => None,
State::Received { at, .. } => Some(at + Duration::from_secs(5)),
}
}
fn public_address(&self) -> Option<SocketAddr> {
match self.state {
State::Sent => None,
State::Received { address, .. } => Some(address),
}
}
}
enum State {
Sent,
Received { address: SocketAddr, at: Instant },
}
struct Transmit {
dst: SocketAddr,
payload: Vec<u8>,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment