From cf3510918a4c9e60918e75667561331f0ad9e1c6 Mon Sep 17 00:00:00 2001 From: Jake McGinty Date: Tue, 14 Sep 2021 15:48:27 +0900 Subject: [PATCH] server: report local candidates for peers to connect (#151) Before, only clients would report local addresses for NAT traversal. Servers should too! This will be helpful in common situations when the server is run inside the same LAN as other peers, and there's no NAT hairpinning enabled (or possible) on the router. closes #146 --- client/src/main.rs | 7 +++---- server/src/api/user.rs | 2 ++ server/src/main.rs | 27 ++++++++++++++++++++++++--- shared/src/lib.rs | 28 ++++++++++++++++++++++++++++ shared/src/netlink.rs | 9 ++++----- shared/src/wg.rs | 25 ------------------------- 6 files changed, 61 insertions(+), 37 deletions(-) diff --git a/client/src/main.rs b/client/src/main.rs index 15f4880..2d2884a 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -4,6 +4,7 @@ use dialoguer::{Confirm, Input}; use hostsfile::HostsBuilder; use indoc::eprintdoc; use shared::{ + get_local_addrs, interface_config::InterfaceConfig, prompts, wg::{DeviceExt, PeerInfoExt}, @@ -448,7 +449,7 @@ fn redeem_invite( target_conf.to_string_lossy().yellow() ); - log::info!("Changing keys and waiting for server's WireGuard interface to transition.",); + log::info!("Changing keys and waiting 5s for server's WireGuard interface to transition.",); DeviceUpdate::new() .set_private_key(keypair.private) .apply(iface, network.backend) @@ -550,10 +551,8 @@ fn fetch( store.update_peers(&peers)?; store.write().with_str(interface.to_string())?; - let candidates = wg::get_local_addrs()? - .into_iter() + let candidates: Vec = get_local_addrs()? .map(|addr| SocketAddr::from((addr, device.listen_port.unwrap_or(51820))).into()) - .take(10) .collect::>(); log::info!( "reporting {} interface address{} as NAT traversal candidates...", diff --git a/server/src/api/user.rs b/server/src/api/user.rs index ffedd76..99509b8 100644 --- a/server/src/api/user.rs +++ b/server/src/api/user.rs @@ -106,6 +106,8 @@ mod handlers { // This might be avoidable if we were able to run code after we were certain the response // had flushed over the TCP socket, but that isn't easily accessible from this high-level // web framework. + // + // Related: https://github.com/hyperium/hyper/issues/2181 tokio::task::spawn(async move { tokio::time::sleep(*REDEEM_TRANSITION_WAIT).await; log::info!( diff --git a/server/src/main.rs b/server/src/main.rs index 065d4a3..4ae602b 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -8,8 +8,8 @@ use parking_lot::{Mutex, RwLock}; use rusqlite::Connection; use serde::{Deserialize, Serialize}; use shared::{ - AddCidrOpts, AddPeerOpts, DeleteCidrOpts, IoErrorContext, NetworkOpt, RenamePeerOpts, - INNERNET_PUBKEY_HEADER, + get_local_addrs, AddCidrOpts, AddPeerOpts, DeleteCidrOpts, Endpoint, IoErrorContext, + NetworkOpt, PeerContents, RenamePeerOpts, INNERNET_PUBKEY_HEADER, }; use std::{ collections::{HashMap, VecDeque}, @@ -479,7 +479,7 @@ async fn serve( log::debug!("opening database connection..."); let conn = open_database_connection(&interface, conf)?; - let peers = DatabasePeer::list(&conn)?; + let mut peers = DatabasePeer::list(&conn)?; log::debug!("peers listed..."); let peer_configs = peers .iter() @@ -502,6 +502,27 @@ async fn serve( log::info!("{} peers added to wireguard interface.", peers.len()); + let candidates: Vec = get_local_addrs()? + .map(|addr| SocketAddr::from((addr, config.listen_port)).into()) + .collect(); + let num_candidates = candidates.len(); + let myself = peers + .iter_mut() + .find(|peer| peer.ip == config.address) + .expect("Couldn't find server peer in peer list."); + myself.update( + &conn, + PeerContents { + candidates, + ..myself.contents.clone() + }, + )?; + + log::info!( + "{} local candidates added to server peer config.", + num_candidates + ); + let public_key = wgctrl::Key::from_base64(&config.private_key)?.generate_public(); let db = Arc::new(Mutex::new(conn)); let endpoints = spawn_endpoint_refresher(interface, network); diff --git a/shared/src/lib.rs b/shared/src/lib.rs index b399ab6..05cda3c 100644 --- a/shared/src/lib.rs +++ b/shared/src/lib.rs @@ -79,3 +79,31 @@ pub fn chmod(file: &File, new_mode: u32) -> Result { Ok(updated) } + +#[cfg(target_os = "macos")] +pub fn _get_local_addrs() -> Result, io::Error> { + use nix::{net::if_::InterfaceFlags, sys::socket::SockAddr}; + + let addrs = nix::ifaddrs::getifaddrs()? + .filter(|addr| { + addr.flags.contains(InterfaceFlags::IFF_UP) + && !addr.flags.intersects( + InterfaceFlags::IFF_LOOPBACK + | InterfaceFlags::IFF_POINTOPOINT + | InterfaceFlags::IFF_PROMISC, + ) + }) + .filter_map(|addr| match addr.address { + Some(SockAddr::Inet(addr)) if addr.to_std().is_ipv4() => Some(addr.to_std().ip()), + _ => None, + }); + + Ok(addrs) +} + +#[cfg(target_os = "linux")] +pub use netlink::get_local_addrs as _get_local_addrs; + +pub fn get_local_addrs() -> Result, io::Error> { + Ok(_get_local_addrs()?.take(10)) +} diff --git a/shared/src/netlink.rs b/shared/src/netlink.rs index 95a268b..d85d011 100644 --- a/shared/src/netlink.rs +++ b/shared/src/netlink.rs @@ -180,7 +180,7 @@ fn get_links() -> Result, io::Error> { Ok(links) } -pub fn get_local_addrs() -> Result, io::Error> { +pub fn get_local_addrs() -> Result, io::Error> { let links = get_links()?; let addr_responses = netlink_call( RtnlMessage::GetAddress(AddressMessage::default()), @@ -203,7 +203,7 @@ pub fn get_local_addrs() -> Result, io::Error> { None }) // Only select addresses for helpful links - .filter(|nlas| nlas.iter().any(|nla| matches!(nla, address::nlas::Nla::Label(label) if links.contains(label)))) + .filter(move |nlas| nlas.iter().any(|nla| matches!(nla, address::nlas::Nla::Label(label) if links.contains(label)))) .filter_map(|nlas| nlas.iter().find_map(|nla| match nla { address::nlas::Nla::Address(name) if name.len() == 4 => { let mut addr = [0u8; 4]; @@ -216,8 +216,7 @@ pub fn get_local_addrs() -> Result, io::Error> { Some(IpAddr::V6(addr.into())) }, _ => None, - })) - .collect::>(); + })); Ok(addrs) } @@ -228,6 +227,6 @@ mod tests { #[test] fn test_local_addrs() { let addrs = get_local_addrs().unwrap(); - println!("{:?}", addrs); + println!("{:?}", addrs.collect::>()); } } diff --git a/shared/src/wg.rs b/shared/src/wg.rs index 5d4dca4..e44367e 100644 --- a/shared/src/wg.rs +++ b/shared/src/wg.rs @@ -166,31 +166,6 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result Result, io::Error> { - use nix::{net::if_::InterfaceFlags, sys::socket::SockAddr}; - - let addrs = nix::ifaddrs::getifaddrs()? - .filter(|addr| { - addr.flags.contains(InterfaceFlags::IFF_UP) - && !addr.flags.intersects( - InterfaceFlags::IFF_LOOPBACK - | InterfaceFlags::IFF_POINTOPOINT - | InterfaceFlags::IFF_PROMISC, - ) - }) - .filter_map(|addr| match addr.address { - Some(SockAddr::Inet(addr)) if addr.to_std().is_ipv4() => Some(addr.to_std().ip()), - _ => None, - }) - .collect::>(); - - Ok(addrs) -} - -#[cfg(target_os = "linux")] -pub use super::netlink::get_local_addrs; - pub trait DeviceExt { /// Diff the output of a wgctrl device with a list of server-reported peers. fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec>;