Forked from conectado/gist:3a920134abacf24b1fc89bb08c08812e
Created August 15, 2024 03:51
-
-
Save thomaseizinger/51f9c44eebd0dd3b58a169d63c4e203d to your computer and use it in GitHub Desktop.
stun-tests.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use anyhow::Context; | |
use bytecodec::{DecodeExt, EncodeExt as _}; | |
use futures::{Future, FutureExt}; | |
use std::{ | |
collections::VecDeque, | |
net::{SocketAddr, ToSocketAddrs}, | |
pin::Pin, | |
task::{ready, Poll, Waker}, | |
time::{Duration, Instant}, | |
}; | |
use stun_codec::rfc5389::{attributes::XorMappedAddress, methods::BINDING, Attribute}; | |
use stun_codec::*; | |
use tokio::net::UdpSocket; | |
pub fn make_binding_request() -> Vec<u8> { | |
let request = Message::<Attribute>::new( | |
MessageClass::Request, | |
BINDING, | |
TransactionId::new(rand::random()), | |
); | |
MessageEncoder::<Attribute>::default() | |
.encode_into_bytes(request) | |
.unwrap() | |
} | |
pub fn parse_binding_response(buf: &[u8]) -> SocketAddr { | |
let message = MessageDecoder::<Attribute>::default() | |
.decode_from_bytes(buf) | |
.unwrap() | |
.unwrap(); | |
message | |
.get_attribute::<XorMappedAddress>() | |
.unwrap() | |
.address() | |
} | |
#[tokio::main] | |
async fn main() -> anyhow::Result<()> { | |
let socket = UdpSocket::bind("0.0.0.0:0").await?; | |
let mut bindings = Vec::from([ | |
StunBinding::new( | |
"34.100.204.176:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.118.92.215:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"35.199.88.84:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.80.109.129:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"35.189.51.106:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.89.78.71:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"35.231.210.194:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.102.65.13:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.135.13.105:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.94.219.196:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.38.30.124:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"35.243.253.170:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
StunBinding::new( | |
"34.88.39.18:3478" | |
.to_socket_addrs()? | |
.next() | |
.context("Failed to resolve hostname")?, | |
), | |
]); | |
let mut timer = Timer::default(); | |
'outer: loop { | |
for binding in &mut bindings { | |
if let Some(transmit) = binding.poll_transmit() { | |
socket.send_to(&transmit.payload, transmit.dst).await?; | |
continue 'outer; | |
} | |
} | |
let mut buf = vec![0u8; 100]; | |
tokio::select! { | |
Some(time) = &mut timer => { | |
for binding in &mut bindings { | |
binding.handle_timeout(time); | |
} | |
}, | |
res = socket.recv_from(&mut buf) => { | |
let (num_read, from) = res?; | |
for binding in &mut bindings { | |
binding.handle_input(from, &buf[..num_read], Instant::now()); | |
} | |
} | |
} | |
for binding in &bindings { | |
timer.reset_to(binding.poll_timeout()); | |
} | |
for binding in &bindings { | |
if let Some(address) = binding.public_address() { | |
println!( | |
"Our public IP from server {:>19} is: {address}", | |
binding.server | |
); | |
} | |
} | |
} | |
} | |
#[derive(Default)] | |
struct Timer { | |
inner: Option<Pin<Box<tokio::time::Sleep>>>, | |
waker: Option<Waker>, | |
} | |
impl Timer { | |
fn reset_to(&mut self, next: Option<Instant>) { | |
let next = match next { | |
Some(next) => next, | |
None => { | |
self.inner = None; | |
return; | |
} | |
}; | |
match self.inner.as_mut() { | |
Some(timer) => { | |
if timer.deadline() <= tokio::time::Instant::from_std(next) { | |
return; | |
} | |
timer.as_mut().reset(next.into()) | |
} | |
None => { | |
self.inner = Some(Box::pin(tokio::time::sleep_until(next.into()))); | |
if let Some(waker) = self.waker.take() { | |
waker.wake() | |
} | |
} | |
} | |
} | |
} | |
impl Future for Timer { | |
type Output = Option<Instant>; | |
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { | |
let mut this = self.as_mut(); | |
let Some(timer) = this.inner.as_mut() else { | |
self.waker = Some(cx.waker().clone()); | |
return Poll::Ready(None); | |
}; | |
ready!(timer.as_mut().poll_unpin(cx)); | |
Poll::Ready(Some(timer.as_ref().deadline().into())) | |
} | |
} | |
struct StunBinding { | |
server: SocketAddr, | |
state: State, | |
buffered_transmits: VecDeque<Transmit>, | |
} | |
impl StunBinding { | |
fn new(server: SocketAddr) -> Self { | |
Self { | |
server, | |
state: State::Sent, | |
buffered_transmits: VecDeque::from([Transmit { | |
dst: server, | |
payload: make_binding_request(), | |
}]), | |
} | |
} | |
fn handle_input(&mut self, from: SocketAddr, packet: &[u8], now: Instant) { | |
if from != self.server { | |
return; | |
} | |
let address = parse_binding_response(packet); | |
self.state = State::Received { address, at: now }; | |
} | |
fn poll_transmit(&mut self) -> Option<Transmit> { | |
self.buffered_transmits.pop_front() | |
} | |
/// Notifies `StunBinding` that time has advanced to `now`. | |
fn handle_timeout(&mut self, now: Instant) { | |
let last_received_at = match self.state { | |
State::Sent => return, | |
State::Received { at, .. } => at, | |
}; | |
if now.duration_since(last_received_at) < Duration::from_secs(5) { | |
return; | |
} | |
self.buffered_transmits.push_front(Transmit { | |
dst: self.server, | |
payload: make_binding_request(), | |
}); | |
self.state = State::Sent; | |
} | |
/// Returns the timestamp when we next expect `handle_timeout` to be called. | |
fn poll_timeout(&self) -> Option<Instant> { | |
match self.state { | |
State::Sent => None, | |
State::Received { at, .. } => Some(at + Duration::from_secs(5)), | |
} | |
} | |
fn public_address(&self) -> Option<SocketAddr> { | |
match self.state { | |
State::Sent => None, | |
State::Received { address, .. } => Some(address), | |
} | |
} | |
} | |
enum State { | |
Sent, | |
Received { address: SocketAddr, at: Instant }, | |
} | |
/// A datagram a `StunBinding` has queued for the caller to send.
struct Transmit {
    /// Encoded STUN message bytes to put on the wire.
    payload: Vec<u8>,
    /// Address the datagram is addressed to.
    dst: SocketAddr,
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment