Skip to content

Commit e2f71a4

Browse files
author
yngrtc
committed
add async_chat.rs
1 parent 0ffe539 commit e2f71a4

File tree

6 files changed

+836
-18
lines changed

6 files changed

+836
-18
lines changed

Cargo.toml

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,42 @@ sctp = { path = "rtc/sctp" }
3838
data = { path = "rtc/data" }
3939

4040
[dev-dependencies]
41-
# sfu impl
41+
# common
42+
chrono = "0.4"
43+
env_logger = "0.11"
4244
clap = { version = "4.4.12", features = ["derive"] }
4345
anyhow = "1.0.78"
46+
rouille = { version = "3.6.2", features = ["ssl"] }
47+
systemstat = "0.2.2"
48+
49+
# sync_chat
4450
wg = "0.7"
4551
crossbeam-channel = "0.5.11"
4652
ctrlc = "3.4.2"
4753

48-
# str0m impl
49-
rouille = { version = "3.6.2", features = ["ssl"] }
50-
systemstat = "0.2.2"
54+
# async_chat
55+
futures = "0.3.30"
56+
smol = "2.0.0"
57+
async-broadcast = "0.6.0"
58+
waitgroup = "0.1.2"
59+
core_affinity = "0.8.1"
60+
num_cpus = "1.16.0"
61+
tokio = { version = "1.36", features = ["full"] }
62+
tokio-util = "0.7"
5163

5264
# tests
5365
webrtc = { path = "webrtc/webrtc" }
54-
tokio = { version = "1.36", features = ["full"] }
55-
tokio-util = "0.7"
56-
chrono = "0.4"
57-
env_logger = "0.11"
5866
hyper = { version = "0.14.28", features = ["full"] }
5967

6068
[[example]]
6169
name = "sync_chat"
6270
path = "examples/sync_chat.rs"
6371
test = false
6472
bench = false
73+
74+
[[example]]
75+
name = "async_chat"
76+
path = "examples/async_chat.rs"
77+
test = false
78+
bench = false
79+

examples/async_chat.rs

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
extern crate num_cpus;
2+
3+
use std::cell::RefCell;
4+
use std::collections::HashMap;
5+
use std::io::Write;
6+
use std::net::SocketAddr;
7+
use std::rc::Rc;
8+
use std::str::FromStr;
9+
use std::sync::Arc;
10+
11+
use async_broadcast::broadcast;
12+
use clap::Parser;
13+
use dtls::extension::extension_use_srtp::SrtpProtectionProfile;
14+
use log::{error, info};
15+
use retty::bootstrap::BootstrapUdpServer;
16+
use retty::channel::Pipeline;
17+
use retty::executor::LocalExecutorBuilder;
18+
use retty::transport::{AsyncTransport, AsyncTransportWrite, TaggedBytesMut};
19+
use waitgroup::WaitGroup;
20+
21+
use sfu::{
22+
DataChannelHandler, DemuxerHandler, DtlsHandler, ExceptionHandler, GatewayHandler,
23+
InterceptorHandler, RTCCertificate, SctpHandler, ServerConfig, ServerStates, SrtpHandler,
24+
StunHandler,
25+
};
26+
27+
mod async_signal;
28+
29+
use async_signal::{handle_signaling_message, SignalingMessage, SignalingServer};
30+
31+
#[derive(Default, Debug, Copy, Clone, clap::ValueEnum)]
32+
enum Level {
33+
Error,
34+
Warn,
35+
#[default]
36+
Info,
37+
Debug,
38+
Trace,
39+
}
40+
41+
impl From<Level> for log::LevelFilter {
42+
fn from(level: Level) -> Self {
43+
match level {
44+
Level::Error => log::LevelFilter::Error,
45+
Level::Warn => log::LevelFilter::Warn,
46+
Level::Info => log::LevelFilter::Info,
47+
Level::Debug => log::LevelFilter::Debug,
48+
Level::Trace => log::LevelFilter::Trace,
49+
}
50+
}
51+
}
52+
53+
#[derive(Parser)]
54+
#[command(name = "SFU Server")]
55+
#[command(author = "Rusty Rain <y@ngr.tc>")]
56+
#[command(version = "0.1.0")]
57+
#[command(about = "An example of SFU Server", long_about = None)]
58+
struct Cli {
59+
#[arg(long, default_value_t = format!("127.0.0.1"))]
60+
host: String,
61+
#[arg(short, long, default_value_t = 8080)]
62+
signal_port: u16,
63+
#[arg(long, default_value_t = 3478)]
64+
media_port_min: u16,
65+
#[arg(long, default_value_t = 3495)]
66+
media_port_max: u16,
67+
68+
#[arg(short, long)]
69+
debug: bool,
70+
#[arg(short, long, default_value_t = Level::Info)]
71+
#[clap(value_enum)]
72+
level: Level,
73+
}
74+
75+
fn main() -> anyhow::Result<()> {
76+
let cli = Cli::parse();
77+
if cli.debug {
78+
env_logger::Builder::new()
79+
.format(|buf, record| {
80+
writeln!(
81+
buf,
82+
"{}:{} [{}] {} - {}",
83+
record.file().unwrap_or("unknown"),
84+
record.line().unwrap_or(0),
85+
record.level(),
86+
chrono::Local::now().format("%H:%M:%S.%6f"),
87+
record.args()
88+
)
89+
})
90+
.filter(None, cli.level.into())
91+
.init();
92+
}
93+
94+
println!(
95+
"listening {}:{}(signal)/[{}-{}](media)...",
96+
cli.host, cli.signal_port, cli.media_port_min, cli.media_port_max
97+
);
98+
99+
let media_ports: Vec<u16> = (cli.media_port_min..=cli.media_port_max).collect();
100+
let (stop_tx, mut stop_rx) = broadcast::<()>(1);
101+
let mut media_port_thread_map = HashMap::new();
102+
103+
let key_pair = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?;
104+
let certificates = vec![RTCCertificate::from_key_pair(key_pair)?];
105+
let dtls_handshake_config = Arc::new(
106+
dtls::config::ConfigBuilder::default()
107+
.with_certificates(
108+
certificates
109+
.iter()
110+
.map(|c| c.dtls_certificate.clone())
111+
.collect(),
112+
)
113+
.with_srtp_protection_profiles(vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80])
114+
.with_extended_master_secret(dtls::config::ExtendedMasterSecretType::Require)
115+
.build(false, None)?,
116+
);
117+
let sctp_endpoint_config = Arc::new(sctp::EndpointConfig::default());
118+
let sctp_server_config = Arc::new(sctp::ServerConfig::default());
119+
let server_config = Arc::new(
120+
ServerConfig::new(certificates)
121+
.with_dtls_handshake_config(dtls_handshake_config)
122+
.with_sctp_endpoint_config(sctp_endpoint_config)
123+
.with_sctp_server_config(sctp_server_config),
124+
);
125+
let wait_group = WaitGroup::new();
126+
let core_num = num_cpus::get();
127+
128+
for port in media_ports {
129+
let worker = wait_group.worker();
130+
let host = cli.host.clone();
131+
let mut stop_rx = stop_rx.clone();
132+
let (signaling_tx, signaling_rx) = smol::channel::unbounded::<SignalingMessage>();
133+
media_port_thread_map.insert(port, signaling_tx);
134+
135+
let server_config = server_config.clone();
136+
LocalExecutorBuilder::new()
137+
.name(format!("media_port_{}", port).as_str())
138+
.core_id(core_affinity::CoreId {
139+
id: (port as usize) % core_num,
140+
})
141+
.spawn(move || async move {
142+
let _worker = worker;
143+
let local_addr = SocketAddr::from_str(&format!("{}:{}", host, port)).unwrap();
144+
let server_states = Rc::new(RefCell::new(ServerStates::new(server_config, local_addr).unwrap()));
145+
146+
info!("listening {}:{}...", host, port);
147+
148+
let server_states_moved = server_states.clone();
149+
let mut bootstrap = BootstrapUdpServer::new();
150+
bootstrap.pipeline(Box::new(
151+
move |writer: AsyncTransportWrite<TaggedBytesMut>| {
152+
let pipeline: Pipeline<TaggedBytesMut, TaggedBytesMut> = Pipeline::new();
153+
154+
let local_addr = writer.get_local_addr();
155+
let async_transport_handler = AsyncTransport::new(writer);
156+
let demuxer_handler = DemuxerHandler::new();
157+
let write_exception_handler = ExceptionHandler::new();
158+
let stun_handler = StunHandler::new();
159+
// DTLS
160+
let dtls_handler = DtlsHandler::new(local_addr, Rc::clone(&server_states_moved));
161+
let sctp_handler = SctpHandler::new(local_addr, Rc::clone(&server_states_moved));
162+
let data_channel_handler = DataChannelHandler::new();
163+
// SRTP
164+
let srtp_handler = SrtpHandler::new(Rc::clone(&server_states_moved));
165+
let interceptor_handler = InterceptorHandler::new(Rc::clone(&server_states_moved));
166+
// Gateway
167+
let gateway_handler = GatewayHandler::new(Rc::clone(&server_states_moved));
168+
let read_exception_handler = ExceptionHandler::new();
169+
170+
pipeline.add_back(async_transport_handler);
171+
pipeline.add_back(demuxer_handler);
172+
pipeline.add_back(write_exception_handler);
173+
pipeline.add_back(stun_handler);
174+
// DTLS
175+
pipeline.add_back(dtls_handler);
176+
pipeline.add_back(sctp_handler);
177+
pipeline.add_back(data_channel_handler);
178+
// SRTP
179+
pipeline.add_back(srtp_handler);
180+
pipeline.add_back(interceptor_handler);
181+
// Gateway
182+
pipeline.add_back(gateway_handler);
183+
pipeline.add_back(read_exception_handler);
184+
185+
pipeline.finalize()
186+
},
187+
));
188+
189+
if let Err(err) = bootstrap.bind(format!("{}:{}", host, port)).await {
190+
error!("bootstrap binding error: {}", err);
191+
return;
192+
}
193+
194+
loop {
195+
tokio::select! {
196+
_ = stop_rx.recv() => {
197+
info!("media server on {}:{} receives stop signal", host, port);
198+
break;
199+
}
200+
recv = signaling_rx.recv() => {
201+
match recv {
202+
Ok(signaling_msg) => {
203+
if let Err(err) = handle_signaling_message(&server_states, signaling_msg) {
204+
error!("handle_signaling_message error: {}", err);
205+
}
206+
}
207+
Err(err) => {
208+
error!("signal_rx recv error: {}", err);
209+
break;
210+
}
211+
}
212+
}
213+
}
214+
}
215+
216+
bootstrap.graceful_stop().await;
217+
info!("media server on {}:{} is gracefully down", host, port);
218+
})?;
219+
}
220+
221+
let signaling_addr = SocketAddr::from_str(&format!("{}:{}", cli.host, cli.signal_port))?;
222+
let signaling_stop_rx = stop_rx.clone();
223+
let signaling_handle = std::thread::spawn(move || {
224+
let rt = tokio::runtime::Builder::new_current_thread()
225+
.enable_io()
226+
.enable_time()
227+
.build()
228+
.unwrap();
229+
230+
rt.block_on(async {
231+
let signaling_server = SignalingServer::new(signaling_addr, media_port_thread_map);
232+
let mut done_rx = signaling_server.run(signaling_stop_rx).await;
233+
let _ = done_rx.recv().await;
234+
wait_group.wait().await;
235+
info!("signaling server is gracefully down");
236+
})
237+
});
238+
239+
LocalExecutorBuilder::default().run(async move {
240+
println!("Press Ctrl-C to stop");
241+
std::thread::spawn(move || {
242+
let mut stop_tx = Some(stop_tx);
243+
ctrlc::set_handler(move || {
244+
if let Some(stop_tx) = stop_tx.take() {
245+
let _ = stop_tx.try_broadcast(());
246+
}
247+
})
248+
.expect("Error setting Ctrl-C handler");
249+
});
250+
let _ = stop_rx.recv().await;
251+
println!("Wait for Signaling Sever and Media Server Gracefully Shutdown...");
252+
});
253+
254+
let _ = signaling_handle.join();
255+
256+
Ok(())
257+
}

0 commit comments

Comments
 (0)