Created
August 15, 2024 03:18
-
-
Save conectado/3a920134abacf24b1fc89bb08c08812e to your computer and use it in GitHub Desktop.
stun-tests.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use anyhow::Context; | |
use bytecodec::{DecodeExt, EncodeExt as _}; | |
use futures::{Future, FutureExt}; | |
use std::{ | |
collections::VecDeque, | |
net::{SocketAddr, ToSocketAddrs}, | |
pin::Pin, | |
task::{ready, Poll, Waker}, | |
time::{Duration, Instant}, | |
}; | |
use stun_codec::rfc5389::{attributes::XorMappedAddress, methods::BINDING, Attribute}; | |
use stun_codec::*; | |
use tokio::net::UdpSocket; | |
pub fn make_binding_request() -> Vec<u8> { | |
let request = Message::<Attribute>::new( | |
MessageClass::Request, | |
BINDING, | |
TransactionId::new(rand::random()), | |
); | |
MessageEncoder::<Attribute>::default() | |
.encode_into_bytes(request) | |
.unwrap() | |
} | |
pub fn parse_binding_response(buf: &[u8]) -> SocketAddr { | |
let message = MessageDecoder::<Attribute>::default() | |
.decode_from_bytes(buf) | |
.unwrap() | |
.unwrap(); | |
message | |
.get_attribute::<XorMappedAddress>() | |
.unwrap() | |
.address() | |
} | |
#[tokio::main] | |
async fn main() -> anyhow::Result<()> { | |
let socket = UdpSocket::bind("0.0.0.0:0").await?; | |
let mut servers = Vec::new(); | |
let server = "stun.cloudflare.com:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.100.204.176:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.118.92.215:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "35.199.88.84:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.80.109.129:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "35.189.51.106:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.89.78.71:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "35.231.210.194:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.102.65.13:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.135.13.105:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.94.219.196:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.38.30.124:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "35.243.253.170:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let server = "34.88.39.18:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?; | |
servers.push(server); | |
let mut binding = StunBinding::new(servers); | |
let mut timer = Timer::default(); | |
loop { | |
for _ in 0..100 { | |
socket | |
.send_to( | |
&[0xab], | |
"10.20.30.40:1234" | |
.to_socket_addrs() | |
.unwrap() | |
.next() | |
.unwrap(), | |
) | |
.await?; | |
} | |
if let Some(transmit) = binding.poll_transmit() { | |
socket.send_to(&transmit.payload, transmit.dst).await?; | |
continue; | |
} | |
let mut buf = vec![0u8; 100]; | |
tokio::select! { | |
Some(time) = &mut timer => { | |
binding.handle_timeout(time); | |
}, | |
res = socket.recv(&mut buf) => { | |
let num_read = res?; | |
binding.handle_input(&buf[..num_read], Instant::now()); | |
} | |
} | |
timer.reset_to(binding.poll_timeout()); | |
if let Some(address) = binding.public_address() { | |
println!("Our public IP is: {address}"); | |
} | |
} | |
} | |
#[derive(Default)] | |
struct Timer { | |
inner: Option<Pin<Box<tokio::time::Sleep>>>, | |
waker: Option<Waker>, | |
} | |
impl Timer { | |
fn reset_to(&mut self, next: Option<Instant>) { | |
let next = match next { | |
Some(next) => next, | |
None => { | |
self.inner = None; | |
return; | |
} | |
}; | |
match self.inner.as_mut() { | |
Some(timer) => timer.as_mut().reset(next.into()), | |
None => { | |
self.inner = Some(Box::pin(tokio::time::sleep_until(next.into()))); | |
if let Some(waker) = self.waker.take() { | |
waker.wake() | |
} | |
} | |
} | |
} | |
} | |
impl Future for Timer { | |
type Output = Option<Instant>; | |
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { | |
let mut this = self.as_mut(); | |
let Some(timer) = this.inner.as_mut() else { | |
self.waker = Some(cx.waker().clone()); | |
return Poll::Ready(None); | |
}; | |
ready!(timer.as_mut().poll_unpin(cx)); | |
Poll::Ready(Some(timer.as_ref().deadline().into())) | |
} | |
} | |
struct StunBinding { | |
servers: Vec<SocketAddr>, | |
state: State, | |
buffered_transmits: VecDeque<Transmit>, | |
} | |
impl StunBinding { | |
fn new(servers: Vec<SocketAddr>) -> Self { | |
let mut buffered_transmits = VecDeque::new(); | |
for server in &servers { | |
buffered_transmits.push_back(Transmit { | |
dst: *server, | |
payload: make_binding_request(), | |
}); | |
} | |
Self { | |
servers, | |
state: State::Sent, | |
buffered_transmits, | |
} | |
} | |
fn handle_input(&mut self, packet: &[u8], now: Instant) { | |
let address = parse_binding_response(packet); | |
self.state = State::Received { address, at: now }; | |
} | |
fn poll_transmit(&mut self) -> Option<Transmit> { | |
self.buffered_transmits.pop_front() | |
} | |
/// Notifies `StunBinding` that time has advanced to `now`. | |
fn handle_timeout(&mut self, now: Instant) { | |
let last_received_at = match self.state { | |
State::Sent => return, | |
State::Received { at, .. } => at, | |
}; | |
if now.duration_since(last_received_at) < Duration::from_secs(5) { | |
return; | |
} | |
for server in &self.servers { | |
self.buffered_transmits.push_front(Transmit { | |
dst: *server, | |
payload: make_binding_request(), | |
}); | |
} | |
self.state = State::Sent; | |
} | |
/// Returns the timestamp when we next expect `handle_timeout` to be called. | |
fn poll_timeout(&self) -> Option<Instant> { | |
match self.state { | |
State::Sent => None, | |
State::Received { at, .. } => Some(at + Duration::from_secs(5)), | |
} | |
} | |
fn public_address(&self) -> Option<SocketAddr> { | |
match self.state { | |
State::Sent => None, | |
State::Received { address, .. } => Some(address), | |
} | |
} | |
} | |
enum State { | |
Sent, | |
Received { address: SocketAddr, at: Instant }, | |
} | |
struct Transmit { | |
dst: SocketAddr, | |
payload: Vec<u8>, | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment