Skip to content

Commit b4331ab

Browse files
committed
feat: Add TCP socket support
Add TCP support for non-windows platforms: - Add helpers to parse TCP addresses - Add TCP transport and helpers for the async implementation Signed-off-by: Kostis Papazafeiropoulos <papazof@gmail.com>
1 parent 2244900 commit b4331ab

File tree

4 files changed

+164
-6
lines changed

4 files changed

+164
-6
lines changed

src/asynchronous/server.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ impl Server {
101101
Ok(self.add_listener(listener))
102102
}
103103

104+
#[cfg(unix)]
105+
/// # Safety
106+
/// The file descriptor must represent a unix listener.
107+
pub unsafe fn add_tcp_listener(self, fd: RawFd) -> Result<Server> {
108+
let listener = Listener::from_raw_tcp_listener_fd(fd)
109+
.map_err(err_to_others_err!(e, "from_raw_tcp_listener_fd error"))?;
110+
Ok(self.add_listener(listener))
111+
}
112+
104113
#[cfg(any(target_os = "linux", target_os = "android"))]
105114
/// # Safety
106115
/// The file descriptor must represent a vsock listener.

src/asynchronous/transport/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ macro_rules! io_other {
2222
#[cfg(unix)]
2323
mod unix;
2424

25+
#[cfg(unix)]
26+
mod tcp;
27+
2528
#[cfg(any(target_os = "linux", target_os = "android"))]
2629
mod vsock;
2730

@@ -43,6 +46,11 @@ impl Listener {
4346
return Self::bind_unix(addr);
4447
}
4548

49+
#[cfg(unix)]
50+
if let Some(addr) = addr.strip_prefix("tcp://") {
51+
return Self::bind_tcp(addr);
52+
}
53+
4654
#[cfg(any(target_os = "linux", target_os = "android"))]
4755
if let Some(addr) = addr.strip_prefix("vsock://") {
4856
return Self::bind_vsock(addr);
@@ -70,6 +78,11 @@ impl Socket {
7078
return Self::connect_unix(addr).await;
7179
}
7280

81+
#[cfg(unix)]
82+
if let Some(addr) = addr.strip_prefix("tcp://") {
83+
return Self::connect_tcp(addr).await;
84+
}
85+
7386
#[cfg(any(target_os = "linux", target_os = "android"))]
7487
if let Some(addr) = addr.strip_prefix("vsock://") {
7588
return Self::connect_vsock(addr).await;

src/asynchronous/transport/tcp.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use std::convert::TryFrom;
2+
use std::io::{Error as IoError, Result as IoResult};
3+
use std::os::fd::{FromRawFd as _, RawFd};
4+
use std::net::{
5+
SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream,
6+
};
7+
8+
use async_stream::stream;
9+
use tokio::net::{TcpListener, TcpStream};
10+
11+
use super::{Listener, Socket};
12+
13+
impl Listener {
14+
pub fn bind_tcp(addr: impl AsRef<str>) -> IoResult<Self> {
15+
let addr = parse_tcp_addr(addr)?;
16+
let listener = StdTcpListener::bind(addr)?;
17+
Self::try_from(listener)
18+
}
19+
20+
/// # Safety
21+
/// The file descriptor must represent a tcp listener.
22+
pub unsafe fn from_raw_tcp_listener_fd(fd: std::os::fd::RawFd) -> IoResult<Self> {
23+
let listener = unsafe { StdTcpListener::from_raw_fd(fd) };
24+
Self::try_from(listener)
25+
}
26+
}
27+
28+
impl Socket {
29+
pub async fn connect_tcp(addr: impl AsRef<str>) -> IoResult<Self> {
30+
let addr = parse_tcp_addr(addr)?;
31+
let socket = StdTcpStream::connect(addr)?;
32+
Self::try_from(socket)
33+
}
34+
35+
/// # Safety
36+
/// The file descriptor must represent a tcp socket.
37+
pub unsafe fn from_raw_tcp_socket_fd(fd: RawFd) -> IoResult<Self> {
38+
let socket = unsafe { StdTcpStream::from_raw_fd(fd) };
39+
Self::try_from(socket)
40+
}
41+
}
42+
43+
impl From<TcpListener> for Listener {
44+
fn from(listener: TcpListener) -> Self {
45+
Self::new(stream! {
46+
loop {
47+
yield listener.accept().await.map(|(socket, _)| socket);
48+
}
49+
})
50+
}
51+
}
52+
53+
impl TryFrom<StdTcpListener> for Listener {
54+
type Error = IoError;
55+
fn try_from(listener: StdTcpListener) -> IoResult<Self> {
56+
listener.set_nonblocking(true)?;
57+
Ok(Self::from(TcpListener::from_std(listener)?))
58+
}
59+
}
60+
61+
impl From<TcpStream> for Socket {
62+
fn from(socket: TcpStream) -> Self {
63+
Self::new(socket)
64+
}
65+
}
66+
67+
impl TryFrom<StdTcpStream> for Socket {
68+
type Error = IoError;
69+
fn try_from(socket: StdTcpStream) -> IoResult<Self> {
70+
socket.set_nonblocking(true)?;
71+
Ok(Self::from(TcpStream::from_std(socket)?))
72+
}
73+
}
74+
75+
fn parse_tcp_addr(addr: impl AsRef<str>) -> IoResult<SocketAddr> {
76+
let addr = addr.as_ref();
77+
78+
addr.parse::<SocketAddr>()
79+
.map_err(|e| io_other!("Failed to parse TCP address '{}': {}", addr, e))
80+
}

src/common.rs

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
99
use nix::fcntl::{fcntl, FcntlArg, OFlag};
1010
use nix::sys::socket::*;
11-
use std::os::unix::io::RawFd;
11+
use std::str::FromStr;
12+
use std::{env, os::unix::io::RawFd};
1213

1314
use crate::error::{Error, Result};
1415

1516
#[derive(Debug, Clone, Copy, PartialEq)]
1617
pub(crate) enum Domain {
1718
Unix,
19+
Tcp,
1820
#[cfg(any(target_os = "linux", target_os = "android"))]
1921
Vsock,
2022
}
@@ -39,6 +41,10 @@ fn parse_sockaddr(addr: &str) -> Result<(Domain, &str)> {
3941
return Ok((Domain::Vsock, addr));
4042
}
4143

44+
if let Some(addr) = addr.strip_prefix("tcp://") {
45+
return Ok((Domain::Tcp, addr));
46+
}
47+
4248
Err(Error::Others(format!("Scheme {addr:?} is not supported")))
4349
}
4450

@@ -53,6 +59,10 @@ fn parse_sockaddr(addr: &str) -> Result<(Domain, &str)> {
5359
return Ok((Domain::Unix, addr));
5460
}
5561

62+
if let Some(addr) = addr.strip_prefix("tcp://") {
63+
return Ok((Domain::Tcp, addr));
64+
}
65+
5666
Err(Error::Others(format!("Scheme {addr:?} is not supported")))
5767
}
5868

@@ -83,8 +93,8 @@ fn make_addr(domain: Domain, sockaddr: &str) -> Result<UnixAddr> {
8393
UnixAddr::new(sockaddr).map_err(err_to_others_err!(e, ""))
8494
}
8595
}
86-
Domain::Vsock => Err(Error::Others(
87-
"function make_addr does not support create vsock socket".to_string(),
96+
Domain::Vsock | Domain::Tcp => Err(Error::Others(
97+
"function make_addr does not support create vsock/tcp socket".to_string(),
8898
)),
8999
}
90100
}
@@ -130,7 +140,7 @@ fn parse_vscok(addr: &str) -> Result<(u32, u32)> {
130140
fn make_socket(sockaddr: &str) -> Result<(RawFd, Domain, Box<dyn SockaddrLike>)> {
131141
let (domain, sockaddrv) = parse_sockaddr(sockaddr)?;
132142

133-
let get_sock_addr = |domain, sockaddr| -> Result<(RawFd, Box<dyn SockaddrLike>)> {
143+
let get_unix_addr = |domain, sockaddr| -> Result<(RawFd, Box<dyn SockaddrLike>)> {
134144
let fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None)
135145
.map_err(|e| Error::Socket(e.to_string()))?;
136146

@@ -141,9 +151,20 @@ fn make_socket(sockaddr: &str) -> Result<(RawFd, Domain, Box<dyn SockaddrLike>)>
141151
let sockaddr = make_addr(domain, sockaddr)?;
142152
Ok((fd, Box::new(sockaddr)))
143153
};
154+
let get_tcp_addr = |sockaddr: &str| -> Result<(RawFd, Box<dyn SockaddrLike>)> {
155+
let fd = socket(AddressFamily::Inet, SockType::Stream, SOCK_CLOEXEC, None)
156+
.map_err(|e| Error::Socket(e.to_string()))?;
157+
158+
#[cfg(target_os = "macos")]
159+
set_fd_close_exec(fd)?;
160+
let sockaddr = SockaddrIn::from_str(sockaddr).map_err(err_to_others_err!(e, ""))?;
161+
162+
Ok((fd, Box::new(sockaddr)))
163+
};
144164

145165
let (fd, sockaddr): (i32, Box<dyn SockaddrLike>) = match domain {
146-
Domain::Unix => get_sock_addr(domain, sockaddrv)?,
166+
Domain::Unix => get_unix_addr(domain, sockaddrv)?,
167+
Domain::Tcp => get_tcp_addr(sockaddrv)?,
147168
#[cfg(any(target_os = "linux", target_os = "android"))]
148169
Domain::Vsock => {
149170
let (cid, port) = parse_vscok(sockaddrv)?;
@@ -162,18 +183,41 @@ fn make_socket(sockaddr: &str) -> Result<(RawFd, Domain, Box<dyn SockaddrLike>)>
162183
Ok((fd, domain, sockaddr))
163184
}
164185

186+
fn set_socket_opts(fd: RawFd, domain: Domain, is_bind: bool) -> Result<()> {
187+
if domain != Domain::Tcp {
188+
return Ok(());
189+
}
190+
191+
if is_bind {
192+
setsockopt(fd, sockopt::ReusePort, &true)?;
193+
}
194+
195+
let tcp_nodelay_enabled = match env::var("TTRPC_TCP_NODELAY_ENABLED") {
196+
Ok(val) if val == "1" || val.eq_ignore_ascii_case("true") => true,
197+
Ok(val) if val == "0" || val.eq_ignore_ascii_case("false") => false,
198+
_ => false,
199+
};
200+
if tcp_nodelay_enabled {
201+
setsockopt(fd, sockopt::TcpNoDelay, &true)?;
202+
}
203+
204+
Ok(())
205+
}
206+
165207
pub(crate) fn do_bind(sockaddr: &str) -> Result<(RawFd, Domain)> {
166208
let (fd, domain, sockaddr) = make_socket(sockaddr)?;
167209

210+
set_socket_opts(fd, domain, true)?;
168211
bind(fd, sockaddr.as_ref()).map_err(err_to_others_err!(e, ""))?;
169212

170213
Ok((fd, domain))
171214
}
172215

173216
/// Creates a unix socket for client.
174217
pub(crate) unsafe fn client_connect(sockaddr: &str) -> Result<RawFd> {
175-
let (fd, _, sockaddr) = make_socket(sockaddr)?;
218+
let (fd, domain, sockaddr) = make_socket(sockaddr)?;
176219

220+
set_socket_opts(fd, domain, false)?;
177221
connect(fd, sockaddr.as_ref())?;
178222

179223
Ok(fd)
@@ -202,6 +246,12 @@ mod tests {
202246
true,
203247
),
204248
("abc:///run/c.sock", None, "", false),
249+
(
250+
"tcp://127.0.0.1:65500",
251+
Some(Domain::Tcp),
252+
"127.0.0.1:65500",
253+
true,
254+
),
205255
] {
206256
let (input, domain, addr, success) = (i.0, i.1, i.2, i.3);
207257
let r = parse_sockaddr(input);
@@ -229,6 +279,12 @@ mod tests {
229279
("Vsock:///run/c.sock", None, "", false),
230280
("unix://@/run/b.sock", None, "", false),
231281
("abc:///run/c.sock", None, "", false),
282+
(
283+
"tcp://127.0.0.1:65500",
284+
Some(Domain::Tcp),
285+
"127.0.0.1:65500",
286+
true,
287+
),
232288
] {
233289
let (input, domain, addr, success) = (i.0, i.1, i.2, i.3);
234290
let r = parse_sockaddr(input);

0 commit comments

Comments
 (0)