Skip to content

Instantly share code, notes, and snippets.

@keepsimple1
Created July 15, 2022 18:38
Show Gist options
  • Save keepsimple1/05eba6c1b01e73e4f8abf2c9f88e531a to your computer and use it in GitHub Desktop.
Save keepsimple1/05eba6c1b01e73e4f8abf2c9f88e531a to your computer and use it in GitHub Desktop.
A patch that sends outgoing mDNS messages per network interface, advertising on each interface only the addresses in that interface's LAN segment (subnet).
diff --git a/Cargo.toml b/Cargo.toml
index db11d7d..4f2febc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,8 +15,8 @@ async = ["flume/async"]
default = ["async"]
[dependencies]
-flume = { version = "0.10", default-features = false } # channel between threads
-if-addrs = "0.7" # get local IP addresses
-log = "0.4.14" # logging
-polling = "2.1.0" # select/poll sockets
-socket2 = { version = "0.4", features = ["all"] } # socket APIs
+flume = { version = "0.10", default-features = false } # channel between threads
+if-addrs = "0.7" # get local IP addresses
+log = "0.4.14" # logging
+polling = "2.1.0" # select/poll sockets
+socket2 = { version = "0.3.19", features = ["reuseport"] } # socket APIs (code uses the 0.3-style API: Domain::ipv4(), SockAddr::as_inet())
diff --git a/src/service_daemon.rs b/src/service_daemon.rs
index 7001dd3..f73054b 100644
--- a/src/service_daemon.rs
+++ b/src/service_daemon.rs
@@ -39,6 +39,7 @@ use crate::{
Receiver,
};
use flume::{bounded, Sender, TrySendError};
+use if_addrs::{IfAddr, Ifv4Addr};
use log::{debug, error};
use polling::Poller;
use socket2::{SockAddr, Socket};
@@ -46,8 +47,7 @@ use std::{
cmp,
collections::HashMap,
fmt,
- io::Read,
- net::{IpAddr, Ipv4Addr, SocketAddrV4},
+ net::{Ipv4Addr, SocketAddrV4},
str, thread,
time::Duration,
vec,
@@ -346,7 +346,9 @@ impl ServiceDaemon {
debug!("announce service: {}", &fullname);
match zc.my_services.get(&fullname) {
Some(info) => {
- zc.broadcast_service(info);
+ for intf in zc.respond_sockets.keys() {
+ zc.broadcast_service_on_intf(info, intf);
+ }
zc.increase_counter(Counter::RegisterResend, 1);
}
None => debug!("announce: cannot find such service {}", &fullname),
@@ -361,15 +363,20 @@ impl ServiceDaemon {
UnregisterStatus::NotFound
}
Some((_k, info)) => {
- let packet = zc.unregister_service(&info);
- zc.increase_counter(Counter::Unregister, 1);
- // repeat for one time just in case some peers miss the message
- if !repeating && !packet.is_empty() {
- let next_time = current_time_millis() + 120;
- zc.retransmissions.push(ReRun {
- next_time,
- command: Command::UnregisterResend(packet),
- });
+ // Owned interface keys, so each one can be moved into its resend command.
+ let intf_list: Vec<Ifv4Addr> = zc.respond_sockets.keys().cloned().collect();
+ for intf in intf_list {
+ let packet = zc.unregister_service(&info, &intf);
+ // repeat for one time just in case some peers miss the message
+ if !repeating && !packet.is_empty() {
+ let next_time = current_time_millis() + 120;
+ zc.retransmissions.push(ReRun {
+ next_time,
+ command: Command::UnregisterResend(packet, intf),
+ });
+ }
}
+ // Count one unregistration per service, not one per interface.
+ zc.increase_counter(Counter::Unregister, 1);
UnregisterStatus::OK
}
@@ -379,9 +384,9 @@ impl ServiceDaemon {
}
}
- Command::UnregisterResend(packet) => {
+ Command::UnregisterResend(packet, intf) => {
debug!("Send a packet length of {}", packet.len());
- zc.send_packet(&packet[..], &zc.broadcast_addr);
+ zc.send_packet(&packet[..], &zc.broadcast_addr, &intf);
zc.increase_counter(Counter::UnregisterResend, 1);
}
@@ -425,7 +430,7 @@ impl ServiceDaemon {
/// Creates a new UDP socket to bind to `port` with REUSEPORT option.
/// `non_block` indicates whether to set O_NONBLOCK for the socket.
fn new_socket(ipv4: Ipv4Addr, port: u16, non_block: bool) -> Result<Socket> {
- let fd = Socket::new(socket2::Domain::IPV4, socket2::Type::DGRAM, None)
+ let fd = Socket::new(socket2::Domain::ipv4(), socket2::Type::dgram(), None)
.map_err(|e| e_fmt!("create socket failed: {}", e))?;
fd.set_reuse_address(true)
@@ -447,35 +452,6 @@ fn new_socket(ipv4: Ipv4Addr, port: u16, non_block: bool) -> Result<Socket> {
Ok(fd)
}
-/// Returns the list of IPv4 addrs assigned to this host, excluding loopback or multicast addrs.
-/// If something goes wrong, returns an empty list and logs an error message.
-fn my_ipv4_addrs() -> Vec<Ipv4Addr> {
- let mut result = Vec::new();
-
- match if_addrs::get_if_addrs() {
- Ok(addr_list) => {
- for addr in addr_list {
- if addr.is_loopback() {
- continue;
- }
- match addr.ip() {
- IpAddr::V4(v4_addr) => {
- if v4_addr.is_multicast() {
- continue;
- }
- result.push(v4_addr);
- }
- IpAddr::V6(_) => {}
- }
- }
- }
- Err(e) => error!("get_if_addrs: {}", e),
- }
- debug!("IPv4 addrs: {:?}", &result);
-
- result
-}
-
struct ReRun {
next_time: u64,
command: Command,
@@ -488,8 +464,8 @@ struct Zeroconf {
listen_socket: Socket,
/// Sockets for outgoing packets. One socket for each non-loopback assigned IPv4.
- /// NOTE: For now we only support multicast.
- respond_sockets: Vec<Socket>,
+ /// Keyed by the interface (address and netmask) whose IP the socket was created with.
+ respond_sockets: HashMap<Ifv4Addr, Socket>,
/// Local registered services
my_services: HashMap<String, ServiceInfo>,
@@ -521,15 +496,33 @@ impl Zeroconf {
let group_addr = Ipv4Addr::new(224, 0, 0, 251);
+ // Get IPv4 interfaces, skipping loopback and multicast addresses.
+ // If enumeration fails, log the error and continue with no interfaces.
+ let my_ifv4addrs: Vec<Ifv4Addr> = match if_addrs::get_if_addrs() {
+ Ok(intf_list) => intf_list
+ .into_iter()
+ .filter(|i| !i.is_loopback())
+ .filter_map(|i| match i.addr {
+ IfAddr::V4(ifv4) if !ifv4.ip.is_multicast() => Some(ifv4),
+ _ => None,
+ })
+ .collect(),
+ Err(e) => {
+ error!("get_if_addrs: {}", e);
+ Vec::new()
+ }
+ };
+ debug!("IPv4 interfaces: {:?}", &my_ifv4addrs);
+
// We create a socket for every outgoing IPv4 interface.
- let mut respond_sockets = Vec::new();
- for ipv4_addr in my_ipv4_addrs() {
+ let mut respond_sockets = HashMap::new();
+ for ifv4addr in my_ifv4addrs {
listen_socket
- .join_multicast_v4(&group_addr, &ipv4_addr)
- .map_err(|e| e_fmt!("join multicast group on addr {}: {}", &ipv4_addr, e))?;
+ .join_multicast_v4(&group_addr, &ifv4addr.ip)
+ .map_err(|e| e_fmt!("join multicast group on addr {}: {}", &ifv4addr.ip, e))?;
- let respond_socket = new_socket(ipv4_addr, udp_port, false)?;
- respond_sockets.push(respond_socket);
+ let respond_socket = new_socket(ifv4addr.ip, udp_port, false)?;
+ respond_sockets.insert(ifv4addr, respond_socket);
}
let broadcast_addr = SocketAddrV4::new(group_addr, MDNS_PORT).into();
@@ -562,7 +553,9 @@ impl Zeroconf {
return;
}
- self.broadcast_service(&info);
+ for intf in self.respond_sockets.keys() {
+ self.broadcast_service_on_intf(&info, intf);
+ }
// RFC 6762 section 8.3.
// ..The Multicast DNS responder MUST send at least two unsolicited
@@ -580,7 +573,7 @@ impl Zeroconf {
}
- /// Send an unsolicited response for owned service
- fn broadcast_service(&self, info: &ServiceInfo) {
+ /// Send an unsolicited response for owned service on interface `intf`.
+ fn broadcast_service_on_intf(&self, info: &ServiceInfo, intf: &Ifv4Addr) {
debug!("broadcast service {}", info.get_fullname());
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
out.add_answer_at_time(
@@ -632,6 +625,9 @@ impl Zeroconf {
);
for addr in info.get_addresses() {
+ if !self.valid_ipv4_on_intf(addr, intf) {
+ continue;
+ }
out.add_answer_at_time(
Box::new(DnsAddress::new(
info.get_hostname(),
@@ -644,10 +640,22 @@ impl Zeroconf {
);
}
- self.send(&out, &self.broadcast_addr);
+ self.send(&out, &self.broadcast_addr, intf);
+ }
+
+ /// Returns true if `addr` and the interface address `intf.ip` share the same
+ /// network prefix under `intf.netmask`, i.e. `addr` is on that interface's subnet.
+ fn valid_ipv4_on_intf(&self, addr: &Ipv4Addr, intf: &Ifv4Addr) -> bool {
+ let netmask = u32::from(intf.netmask);
+ let intf_u32 = u32::from(intf.ip);
+ let addr_u32 = u32::from(*addr);
+ let intf_net = intf_u32 & netmask;
+ let addr_net = addr_u32 & netmask;
+ addr_net == intf_net
}
- fn unregister_service(&self, info: &ServiceInfo) -> Vec<u8> {
+ /// Sends the unregister response for `info` on `intf`; returns the packet bytes (empty if nothing was sent).
+ fn unregister_service(&self, info: &ServiceInfo, intf: &Ifv4Addr) -> Vec<u8> {
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
out.add_answer_at_time(
Box::new(DnsPointer::new(
@@ -698,6 +703,9 @@ impl Zeroconf {
);
for addr in info.get_addresses() {
+ if !self.valid_ipv4_on_intf(addr, intf) {
+ continue;
+ }
out.add_answer_at_time(
Box::new(DnsAddress::new(
info.get_hostname(),
@@ -710,7 +718,7 @@ impl Zeroconf {
);
}
- self.send(&out, &self.broadcast_addr)
+ self.send(&out, &self.broadcast_addr, intf)
}
/// Binds a channel `listener` to querying mDNS domain type `ty`.
@@ -721,7 +729,7 @@ impl Zeroconf {
}
- /// Sends an outgoing packet, and returns the packet bytes.
- fn send(&self, out: &DnsOutgoing, addr: &SockAddr) -> Vec<u8> {
+ /// Sends an outgoing packet on the socket for interface `intf`, and returns the packet bytes.
+ fn send(&self, out: &DnsOutgoing, addr: &SockAddr, intf: &Ifv4Addr) -> Vec<u8> {
let qtype = if out.is_query() { "query" } else { "response" };
debug!(
"Sending {} to {:?}: {} questions {} answers {} authorities {} additional",
@@ -738,16 +746,18 @@ impl Zeroconf {
return Vec::new();
}
- self.send_packet(&packet[..], addr);
+ self.send_packet(&packet[..], addr, intf);
packet
}
- fn send_packet(&self, packet: &[u8], addr: &SockAddr) {
- for s in self.respond_sockets.iter() {
- match s.send_to(packet, addr) {
+ /// Sends `packet` to `addr` via the socket created for `intf`; logs an error if there is none.
+ fn send_packet(&self, packet: &[u8], addr: &SockAddr, intf: &Ifv4Addr) {
+ match self.respond_sockets.get(intf) {
+ Some(s) => match s.send_to(packet, addr) {
Ok(sz) => debug!("sent out {} bytes on socket {:?}", sz, s),
Err(e) => error!("send failed: {}", e),
- }
+ },
+ None => error!("cannot find socket for interface: {:?}", intf),
}
}
@@ -755,15 +764,17 @@ impl Zeroconf {
debug!("Sending multicast query for {}", name);
let mut out = DnsOutgoing::new(FLAGS_QR_QUERY);
out.add_question(name, qtype);
- self.send(&out, &self.broadcast_addr);
+ for intf in self.respond_sockets.keys() {
+ self.send(&out, &self.broadcast_addr, intf);
+ }
}
/// Returns false if failed to receive a packet,
/// otherwise returns true.
fn handle_read(&mut self) -> bool {
let mut buf = vec![0u8; MAX_MSG_ABSOLUTE];
- let sz = match self.listen_socket.read(&mut buf) {
- Ok(sz) => sz,
+ let (sz, src) = match self.listen_socket.recv_from(&mut buf) {
+ Ok(received) => received,
Err(e) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
error!("listening socket read failed: {}", e);
@@ -777,7 +788,7 @@ impl Zeroconf {
match DnsIncoming::new(buf) {
Ok(msg) => {
if msg.is_query() {
- self.handle_query(msg);
+ self.handle_query(msg, &src);
} else if msg.is_response() {
self.handle_response(msg);
} else {
@@ -1024,7 +1035,28 @@ impl Zeroconf {
}
}
- fn handle_query(&mut self, msg: DnsIncoming) {
+ fn handle_query(&mut self, msg: DnsIncoming, src: &SockAddr) {
+ // Reply on the interface whose IPv4 subnet contains the querier's address.
+ // Comparing `src_ip` for equality with our own interface IPs would only
+ // match queries this host sent to itself, so peers would never get answers.
+ let src_ip = match src.as_inet() {
+ Some(addr) => *addr.ip(),
+ None => return, // not an IPv4 source address
+ };
+ // Owned copy, so no borrow of `self` is held while building the reply.
+ let intf = match self
+ .respond_sockets
+ .keys()
+ .find(|i| self.valid_ipv4_on_intf(&src_ip, i))
+ .cloned()
+ {
+ Some(i) => i,
+ None => {
+ debug!("no interface on the same subnet as query source {}", src_ip);
+ return;
+ }
+ };
+
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
// Special meta-query "_services._dns-sd._udp.<Domain>".
@@ -1130,7 +1157,7 @@ impl Zeroconf {
if !out.answers.is_empty() {
out.id = msg.id;
- self.send(&out, &self.broadcast_addr);
+ self.send(&out, &self.broadcast_addr, &intf);
self.increase_counter(Counter::Respond, 1);
}
@@ -1179,7 +1206,7 @@ enum Command {
RegisterResend(String), // (fullname)
/// Resend unregister packet.
- UnregisterResend(Vec<u8>), // (packet content)
+ UnregisterResend(Vec<u8>, Ifv4Addr), // (packet content, interface)
/// Stop browsing a service type
StopBrowse(String), // (ty_domain)
@@ -1359,22 +1386,3 @@ fn call_listener(
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::my_ipv4_addrs;
-
- #[test]
- fn test_my_ipv4_addrs() {
- let addrs = my_ipv4_addrs();
-
- // The test host should have at least one IPv4 addr.
- assert!(!addrs.is_empty());
-
- // Verify the address attributes.
- for addr in addrs {
- assert!(!addr.is_loopback());
- assert!(!addr.is_multicast());
- }
- }
-}
diff --git a/tests/mdns_test.rs b/tests/mdns_test.rs
index 7db9bc7..c012cf7 100644
--- a/tests/mdns_test.rs
+++ b/tests/mdns_test.rs
@@ -18,8 +18,8 @@ fn integration_success() {
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
let instance_name = now.as_micros().to_string(); // Create a unique name.
- let host_ipv4 = "192.168.1.12";
- let host_name = "192.168.1.12.";
+ let host_ipv4 = "192.168.0.12"; // NOTE(review): A records outside the daemon's interface subnets are omitted; assumes the test host has an interface whose subnet contains this address — confirm.
+ let host_name = "192.168.0.12.";
let port = 5200;
let mut properties = HashMap::new();
properties.insert("property_1".to_string(), "test".to_string());
@@ -200,8 +200,8 @@ fn service_without_properties() {
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
let instance_name = now.as_micros().to_string(); // Create a unique name.
- let host_ipv4 = "192.168.1.13";
- let host_name = "192.168.1.13.";
+ let host_ipv4 = "192.168.0.13"; // NOTE(review): same subnet assumption as in integration_success — confirm on the CI host.
+ let host_name = "192.168.0.13.";
let port = 5201;
let my_service = ServiceInfo::new(ty_domain, &instance_name, host_name, host_ipv4, port, None)
.expect("valid service info");
@@ -246,8 +246,8 @@ fn subtype() {
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
let instance_name = now.as_micros().to_string(); // Create a unique name.
- let host_ipv4 = "192.168.1.13";
- let host_name = "192.168.1.13.";
+ let host_ipv4 = "192.168.0.13"; // NOTE(review): same subnet assumption as in integration_success — confirm on the CI host.
+ let host_name = "192.168.0.13.";
let port = 5201;
let my_service = ServiceInfo::new(
subtype_domain,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment