diff --git a/Cargo.lock b/Cargo.lock index e9c60e6..bcf135f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -127,6 +127,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "socket2", "tokio", "tokio-util", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 60afaaa..60cc0c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ tokio-util = { version = "0.7.1", features = ["codec"] } tracing = "0.1.32" tracing-subscriber = "0.3.18" uuid = { version = "1.2.1", features = ["serde", "v4"] } +socket2 = { version = "0.4.9" } [dev-dependencies] lazy_static = "1.4.0" diff --git a/src/client.rs b/src/client.rs index 2c21c16..1edeeff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,6 @@ //! Client implementation for the `bore` service. +use std::io::Read; use std::sync::Arc; use anyhow::{bail, Context, Result}; @@ -10,6 +11,7 @@ use uuid::Uuid; use crate::auth::Authenticator; use crate::shared::{ proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT, + BORE_KEEPINTERVAL, tcp_keepalive, parse_envvar_u64, }; /// State structure for the client. @@ -20,9 +22,12 @@ pub struct Client { /// Destination address of the server. to: String, - // Local host that is forwarded. + /// Local host that is forwarded. local_host: String, + /// Local host identity string. + host_id: String, + /// Local port that is forwarded. local_port: u16, @@ -33,24 +38,68 @@ pub struct Client { auth: Option, } +fn random_hostid(idlen: usize) -> String { + if idlen <= 4 { + if idlen == 0 { "".to_string() } else { std::iter::repeat_with(fastrand::alphanumeric).take(idlen).collect() } + } else { + let mut newid = String::with_capacity(idlen + 1); + let idlen = idlen - 4; + newid.push_str("TMP_"); + let rid: String = std::iter::repeat_with(fastrand::alphanumeric).take(idlen).collect(); + newid.push_str(&rid); + newid + } +} + +fn read_hostid_fromfile(idlen: usize) -> String { + // Get the text file containing `BORE_HOSTID + let idfile: std::ffi::OsString = match std::env::var_os("BORE_IDFILE") { + Some(fpath) => fpath, + None => std::ffi::OsString::from("/tmp/bore_hostid.txt"), + }; + + let hfile = std::fs::OpenOptions::new().read(true) + .write(false).create(false).open(&idfile); + if hfile.is_err() { + return random_hostid(idlen); + } + + let mut hfile = hfile.unwrap(); + let mut idbuf = vec![0u8; idlen]; + let rlen = hfile.read(&mut idbuf[..]).unwrap_or(0); + if rlen == 0 { + return random_hostid(idlen); + } + + let idstr = String::from_utf8_lossy(&idbuf[..rlen]); + let hostid: &str = idstr.trim(); + if hostid.is_empty() { random_hostid(idlen) } else { hostid.to_string() } +} + impl Client { /// Create a new client. pub async fn new( local_host: &str, local_port: u16, + id_str: &str, to: &str, port: u16, secret: Option<&str>, ) -> Result { - let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?); + let kval = parse_envvar_u64(BORE_KEEPINTERVAL, 120); + let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT, kval).await?); let auth = secret.map(Authenticator::new); if let Some(auth) = &auth { auth.client_handshake(&mut stream).await?; } - stream.send(ClientMessage::Hello(port)).await?; + // Determine host ID for remote bore server + let hostid = if id_str.is_empty() { read_hostid_fromfile(16) } else { id_str.to_string() }; + info!(hostid, "Using client IDString"); + + stream.send(ClientMessage::Hello(port, hostid.clone())).await?; let remote_port = match stream.recv_timeout().await? { - Some(ServerMessage::Hello(remote_port)) => remote_port, + Some(ServerMessage::Hello(remote_port, _)) => remote_port, Some(ServerMessage::Error(message)) => bail!("server error: {message}"), Some(ServerMessage::Challenge(_)) => { bail!("server requires authentication, but no client secret was provided"); @@ -65,6 +114,7 @@ impl Client { conn: Some(stream), to: to.to_string(), local_host: local_host.to_string(), + host_id: hostid, local_port, remote_port, auth, @@ -82,17 +132,18 @@ impl Client { let this = Arc::new(self); loop { match conn.recv().await? { - Some(ServerMessage::Hello(_)) => warn!("unexpected hello"), + Some(ServerMessage::Hello(_, _)) => warn!("unexpected hello"), Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"), Some(ServerMessage::Heartbeat) => (), Some(ServerMessage::Connection(id)) => { let this = Arc::clone(&this); + let hostid: String = this.host_id.clone(); tokio::spawn( async move { - info!("new connection"); + info!(hostid, "new connection"); match this.handle_connection(id).await { - Ok(_) => info!("connection exited"), - Err(err) => warn!(%err, "connection exited with error"), + Ok(_) => info!(hostid, "connection exited"), + Err(err) => warn!(hostid, %err, "connection exited with error"), } } .instrument(info_span!("proxy", %id)), @@ -105,13 +156,14 @@ impl Client { } async fn handle_connection(&self, id: Uuid) -> Result<()> { + let kval = parse_envvar_u64("BORE_KEEPINTERVAL", 120); let mut remote_conn = - Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?); + Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT, kval).await?); if let Some(auth) = &self.auth { auth.client_handshake(&mut remote_conn).await?; } remote_conn.send(ClientMessage::Accept(id)).await?; - let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?; + let mut local_conn = connect_with_timeout(&self.local_host, self.local_port, kval).await?; let parts = remote_conn.into_parts(); debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty"); local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty @@ -120,9 +172,9 @@ impl Client { } } -async fn connect_with_timeout(to: &str, port: u16) -> Result { +async fn connect_with_timeout(to: &str, port: u16, keepival: u64) -> Result { match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await { - Ok(res) => res, + Ok(res) => if res.is_ok() { Ok(tcp_keepalive(res.unwrap(), 3, keepival)) } else { res }, Err(err) => Err(err.into()), } .with_context(|| format!("could not connect to {to}:{port}")) diff --git a/src/main.rs b/src/main.rs index 71429c4..d808017 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,10 @@ enum Command { #[clap(short, long, value_name = "HOST", default_value = "localhost")] local_host: String, + /// The host ID-String for remote server + #[clap(short, long, env = "BORE_HOSTID", default_value = "")] + id_string: String, + /// Address of the remote server to expose local ports to. #[clap(short, long, env = "BORE_SERVER")] to: String, @@ -66,11 +70,12 @@ async fn run(command: Command) -> Result<()> { Command::Local { local_host, local_port, + id_string, to, port, secret, } => { - let client = Client::new(&local_host, local_port, &to, port, secret.as_deref()).await?; + let client = Client::new(&local_host, local_port, &id_string, &to, port, secret.as_deref()).await?; client.listen().await?; } Command::Server { diff --git a/src/server.rs b/src/server.rs index 3c38988..414ec56 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,17 +2,38 @@ use std::net::{IpAddr, Ipv4Addr}; use std::{io, ops::RangeInclusive, sync::Arc, time::Duration}; +use std::collections::HashMap; use anyhow::Result; use dashmap::DashMap; use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; -use tokio::time::{sleep, timeout}; +use tokio::sync::Mutex; use tracing::{info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT}; +use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, + BORE_KEEPINTERVAL, tcp_keepalive, parse_envvar_u64, +}; + +/// Client information structure +struct ClientInfo { + /// Port number previously used for the `hostid. + port_no: u16, + + /// Whether the client is online. + online: bool, + + /// UTC epoch time when the client connects or disconnects. + last_dance: u64, + + /// Number of times the client disconnects. + num_discon: u64, + + /// Oneshot channel to inform previous task to terminate. + cli_exit: Option>, +} /// State structure for the server. pub struct Server { @@ -30,6 +51,52 @@ pub struct Server { /// IP address where tunnels will listen on. bind_tunnels: IpAddr, + + /// HashMap-ped client information + clients: Arc>>, +} + +impl ClientInfo { + fn new(pno: u16, online: bool) -> Self { + ClientInfo { + port_no: pno, + online, + last_dance: ClientInfo::dance_utc(), + num_discon: 0u64, + cli_exit: None, + } + } + + fn dance_utc() -> u64 { + match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { + Ok(utc) => utc.as_secs(), + Err(_) => 0u64, + } + } + + /// This is an rudimentary implementation. When `bore server is handling + /// an enormous number of remote clients, leaking clients information + /// with this simple function is not recommended, as it might incur serious + /// performance penalty due to holding the hashmap `Mutex for quite a long time. + async fn leak_clients_info(mut tcp: TcpStream, clients: Arc>>) -> usize { + let mut cnum: usize = 0; + let ctable = clients.lock().await; + for (hostid, cinfo) in ctable.iter() { + let oneline = format!("client[{}] => hostid: {}, online: {}, portno: {}, last_dance: {}, discon: {}\n", + cnum, hostid, cinfo.online, cinfo.port_no, cinfo.last_dance, cinfo.num_discon); + if let Err(err) = tcp.write(oneline.as_bytes()).await { + warn!(%err, "Failed to leak clients information"); + break; + } + cnum += 1; + } + // drop the mutex lock explicitly + drop(ctable); + tokio::time::sleep(std::time::Duration::from_millis(660)).await; + let _ = tcp.shutdown().await; + drop(tcp); + cnum + } } impl Server { @@ -42,6 +109,7 @@ impl Server { auth: secret.map(Authenticator::new), bind_addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), bind_tunnels: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + clients: Arc::new(Mutex::new(HashMap::new())), } } @@ -61,20 +129,41 @@ impl Server { let listener = TcpListener::bind((this.bind_addr, CONTROL_PORT)).await?; info!(addr = ?this.bind_addr, "server listening"); + // Create a listening socket for the query of existing clients. + let insider = TcpListener::bind("127.0.0.1:10088").await?; + + // TCP connection keep-alive interval. + let kval = parse_envvar_u64(BORE_KEEPINTERVAL, 55); + + // Heartbeat interval in seconds. For large number of devices, + // the frequency of heartbeat from server should be decreased. + let hbeat = parse_envvar_u64("BORE_HEARTBEAT_INTERVAL", 180); + let hbeat = if hbeat < 15 { 15u64 } else { hbeat }; + loop { - let (stream, addr) = listener.accept().await?; - let this = Arc::clone(&this); - tokio::spawn( - async move { - info!("incoming connection"); - if let Err(err) = this.handle_connection(stream).await { - warn!(%err, "connection exited with error"); - } else { - info!("connection exited"); + tokio::select! { + Ok((stream, addr)) = listener.accept() => { + let this = Arc::clone(&this); + tokio::spawn(async move { + info!("incoming connection"); + if let Err(err) = this.handle_connection(stream, kval, hbeat).await { + warn!(%err, "connection exited with error"); + } else { + info!("connection exited"); + } } - } - .instrument(info_span!("control", ?addr)), - ); + .instrument(info_span!("control", ?addr)), + ); + }, + Ok((spy, addr)) = insider.accept() => { + info!(%addr, "clients information leakage"); + let clients = Arc::clone(&this.clients); + tokio::spawn(async move { + let spy = tcp_keepalive(spy, 2, 3); + let _ = ClientInfo::leak_clients_info(spy, clients).await; + }); + }, + } } } @@ -115,7 +204,63 @@ impl Server { } } - async fn handle_connection(&self, stream: TcpStream) -> Result<()> { + async fn find_client_port(&self, hostid: &str) -> u16 { + let mut pno = 0u16; + let mut waitc = false; + // Lock client hashmap table: + let mut ctable = self.clients.lock().await; + if let Some(client) = ctable.get_mut(hostid) { + pno = client.port_no; + if let Some(cexit) = client.cli_exit.take() { + waitc = cexit.send(pno).is_ok(); + } + } + + // release mutex lock as quickly as we can + drop(ctable); + if waitc { + // wait another tokio task occupying `pno to exit + tokio::time::sleep(std::time::Duration::from_millis(210)).await; + } + pno + } + + // insert the hostid into `self.clients hashmap + async fn update_client_port(&self, hostid: &str, pno: u16, online: bool) + -> Option> { + let hostid = hostid.to_string(); + // Lock client hashmap table: + let mut ctable = self.clients.lock().await; + if let Some(oldcli) = ctable.get_mut(&hostid) { + // Do not update `port_no when a lingering task has figured + // out that a previously established client has closed connection. + if online || pno == oldcli.port_no { + oldcli.port_no = pno; + oldcli.online = online; + oldcli.last_dance = ClientInfo::dance_utc(); + } + + if online { + let (tx, rx) = tokio::sync::oneshot::channel::(); + oldcli.cli_exit = Some(tx); + return Some(rx); + } + oldcli.num_discon += 1; + } else { + let mut newcli = ClientInfo::new(pno, online); + if online { + let (tx, rx) = tokio::sync::oneshot::channel::(); + newcli.cli_exit = Some(tx); + ctable.insert(hostid, newcli); + return Some(rx); + } + ctable.insert(hostid, newcli); + } + None + } + + async fn handle_connection(&self, stream: TcpStream, keepival: u64, bhval: u64) -> Result<()> { + let stream = tcp_keepalive(stream, 3, keepival); let mut stream = Delimited::new(stream); if let Some(auth) = &self.auth { if let Err(err) = auth.server_handshake(&mut stream).await { @@ -130,41 +275,85 @@ impl Server { warn!("unexpected authenticate"); Ok(()) } - Some(ClientMessage::Hello(port)) => { - let listener = match self.create_listener(port).await { + Some(ClientMessage::Hello(port, hostid)) => { + if hostid.is_empty() { + // Disallow empty host-ID + return Ok(()); + } + // Try to reuse previously used port for specific `hostid + let pre = self.find_client_port(&hostid).await; + let pno = if port != 0 { port } else { pre }; + let listener = match self.create_listener(pno).await { Ok(listener) => listener, Err(err) => { - stream.send(ServerMessage::Error(err.into())).await?; - return Ok(()); + // if previous listener uses port zero, just return an error to client. + if pno == 0 { + stream.send(ServerMessage::Error(err.into())).await?; + return Ok(()); + } + // Try again with port number zero, as previously used port might be occupied: + match self.create_listener(0).await { + Ok(listener) => listener, + Err(err) => { + stream.send(ServerMessage::Error(err.into())).await?; + return Ok(()); + } + } } }; let host = listener.local_addr()?.ip(); let port = listener.local_addr()?.port(); info!(?host, ?port, "new client"); - stream.send(ServerMessage::Hello(port)).await?; + stream.send(ServerMessage::Hello(port, hostid.clone())).await?; + + // Create an timer for sending heart-beat messages + let mut hbt_it = tokio::time::interval(std::time::Duration::from_secs(bhval)); + hbt_it.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + // Get oneshot receiver handle, informing that following loop should terminate + let mut cli_rx = self.update_client_port(&hostid, port, true).await.unwrap(); loop { - if stream.send(ServerMessage::Heartbeat).await.is_err() { - // Assume that the TCP connection has been dropped. - return Ok(()); - } - const TIMEOUT: Duration = Duration::from_millis(500); - if let Ok(result) = timeout(TIMEOUT, listener.accept()).await { - let (stream2, addr) = result?; - info!(?addr, ?port, "new connection"); - - let id = Uuid::new_v4(); - let conns = Arc::clone(&self.conns); - - conns.insert(id, stream2); - tokio::spawn(async move { - // Remove stale entries to avoid memory leaks. - sleep(Duration::from_secs(10)).await; - if conns.remove(&id).is_some() { - warn!(%id, "removed stale connection"); + tokio::select! { + _ = hbt_it.tick() => { + if stream.send(ServerMessage::Heartbeat).await.is_err() { + // Assume that the TCP connection has been dropped. + let _ = self.update_client_port(&hostid, port, false).await; + return Ok(()); + } + }, + result = listener.accept() => { + if let Err(err) = result { + let _ = self.update_client_port(&hostid, port, false).await; + warn!(%err, "failed to parse incoming proxy request."); + return Err(err.into()); + } + let (stream2, addr) = result.unwrap(); + let stream2 = tcp_keepalive(stream2, 3, keepival); + info!(?addr, ?port, "new connection"); + + let id = Uuid::new_v4(); + let conns = Arc::clone(&self.conns); + + conns.insert(id, stream2); + tokio::spawn(async move { + // Remove stale entries to avoid memory leaks. + tokio::time::sleep(Duration::from_secs(10)).await; + if conns.remove(&id).is_some() { + warn!(%id, "removed stale connection"); + } + }); + if let Err(err) = stream.send(ServerMessage::Connection(id)).await { + let _ = self.update_client_port(&hostid, port, false).await; + return Err(err); } - }); - stream.send(ServerMessage::Connection(id)).await?; + }, + cexit = &mut cli_rx => { + drop(listener); // release port-number occupied by the listener + let _ = self.update_client_port(&hostid, port, false).await; + let forced = cexit.is_ok(); + warn!(hostid, forced, "client has been dropped"); + return Ok(()); + }, } } } diff --git a/src/shared.rs b/src/shared.rs index 10b1bc8..d362bec 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -8,7 +8,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::io::{self, AsyncRead, AsyncWrite}; use tokio::time::timeout; use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts}; -use tracing::trace; +use tracing::{trace, warn}; use uuid::Uuid; /// TCP port used for control connections with the server. @@ -20,6 +20,44 @@ pub const MAX_FRAME_LENGTH: usize = 256; /// Timeout for network connections and initial protocol messages. pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3); +/// Environment variable name for TCP keep-alive settings +pub const BORE_KEEPINTERVAL: &str = "BORE_KEEPINTERVAL"; + +/// Get and parse an environment variable as u64 +pub fn parse_envvar_u64(name: &str, dftval: u64) -> u64 { + match std::env::var(name) { + Ok(val) => { + if let Ok(enval) = str::parse::(&val) { enval } else { dftval } + }, + Err(_) => dftval, + } +} + +/// Change default TCP KEEPALIVE settings for a given TcpStream +pub fn tcp_keepalive(tcps: tokio::net::TcpStream, + retries: u32, keep_ival: u64) -> tokio::net::TcpStream { + // Reference it as an socket2 object + let tcpr = socket2::SockRef::from(&tcps); + + // enable or disable TCP keepalive + let errn = if retries == 0 || keep_ival == 0 { + tcpr.set_keepalive(false) + } else { + let kaopt = socket2::TcpKeepalive::new() + .with_retries(retries) + .with_time(std::time::Duration::from_secs(keep_ival)) + .with_interval(std::time::Duration::from_secs(keep_ival)); + let _ = tcpr.set_keepalive(true); + tcpr.set_tcp_keepalive(&kaopt) + }; + + if let Err(err) = errn { + warn!(%err, "failed to enable/disable TCP keepalive."); + } + let _ = tcpr.set_nonblocking(true); + tcps +} + /// A message from the client on the control connection. #[derive(Debug, Serialize, Deserialize)] pub enum ClientMessage { @@ -27,7 +65,7 @@ pub enum ClientMessage { Authenticate(String), /// Initial client message specifying a port to forward. - Hello(u16), + Hello(u16, String), /// Accepts an incoming TCP connection, using this stream as a proxy. Accept(Uuid), @@ -40,7 +78,7 @@ pub enum ServerMessage { Challenge(Uuid), /// Response to a client's initial message, with actual public port. - Hello(u16), + Hello(u16, String), /// No-op used to test if the client is still reachable. Heartbeat,