|
| 1 | +extern crate num_cpus; |
| 2 | + |
| 3 | +use std::cell::RefCell; |
| 4 | +use std::collections::HashMap; |
| 5 | +use std::io::Write; |
| 6 | +use std::net::SocketAddr; |
| 7 | +use std::rc::Rc; |
| 8 | +use std::str::FromStr; |
| 9 | +use std::sync::Arc; |
| 10 | + |
| 11 | +use async_broadcast::broadcast; |
| 12 | +use clap::Parser; |
| 13 | +use dtls::extension::extension_use_srtp::SrtpProtectionProfile; |
| 14 | +use log::{error, info}; |
| 15 | +use retty::bootstrap::BootstrapUdpServer; |
| 16 | +use retty::channel::Pipeline; |
| 17 | +use retty::executor::LocalExecutorBuilder; |
| 18 | +use retty::transport::{AsyncTransport, AsyncTransportWrite, TaggedBytesMut}; |
| 19 | +use waitgroup::WaitGroup; |
| 20 | + |
| 21 | +use sfu::{ |
| 22 | + DataChannelHandler, DemuxerHandler, DtlsHandler, ExceptionHandler, GatewayHandler, |
| 23 | + InterceptorHandler, RTCCertificate, SctpHandler, ServerConfig, ServerStates, SrtpHandler, |
| 24 | + StunHandler, |
| 25 | +}; |
| 26 | + |
| 27 | +mod async_signal; |
| 28 | + |
| 29 | +use async_signal::{handle_signaling_message, SignalingMessage, SignalingServer}; |
| 30 | + |
| 31 | +#[derive(Default, Debug, Copy, Clone, clap::ValueEnum)] |
| 32 | +enum Level { |
| 33 | + Error, |
| 34 | + Warn, |
| 35 | + #[default] |
| 36 | + Info, |
| 37 | + Debug, |
| 38 | + Trace, |
| 39 | +} |
| 40 | + |
| 41 | +impl From<Level> for log::LevelFilter { |
| 42 | + fn from(level: Level) -> Self { |
| 43 | + match level { |
| 44 | + Level::Error => log::LevelFilter::Error, |
| 45 | + Level::Warn => log::LevelFilter::Warn, |
| 46 | + Level::Info => log::LevelFilter::Info, |
| 47 | + Level::Debug => log::LevelFilter::Debug, |
| 48 | + Level::Trace => log::LevelFilter::Trace, |
| 49 | + } |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +#[derive(Parser)] |
| 54 | +#[command(name = "SFU Server")] |
| 55 | +#[command(author = "Rusty Rain <y@ngr.tc>")] |
| 56 | +#[command(version = "0.1.0")] |
| 57 | +#[command(about = "An example of SFU Server", long_about = None)] |
| 58 | +struct Cli { |
| 59 | + #[arg(long, default_value_t = format!("127.0.0.1"))] |
| 60 | + host: String, |
| 61 | + #[arg(short, long, default_value_t = 8080)] |
| 62 | + signal_port: u16, |
| 63 | + #[arg(long, default_value_t = 3478)] |
| 64 | + media_port_min: u16, |
| 65 | + #[arg(long, default_value_t = 3495)] |
| 66 | + media_port_max: u16, |
| 67 | + |
| 68 | + #[arg(short, long)] |
| 69 | + debug: bool, |
| 70 | + #[arg(short, long, default_value_t = Level::Info)] |
| 71 | + #[clap(value_enum)] |
| 72 | + level: Level, |
| 73 | +} |
| 74 | + |
| 75 | +fn main() -> anyhow::Result<()> { |
| 76 | + let cli = Cli::parse(); |
| 77 | + if cli.debug { |
| 78 | + env_logger::Builder::new() |
| 79 | + .format(|buf, record| { |
| 80 | + writeln!( |
| 81 | + buf, |
| 82 | + "{}:{} [{}] {} - {}", |
| 83 | + record.file().unwrap_or("unknown"), |
| 84 | + record.line().unwrap_or(0), |
| 85 | + record.level(), |
| 86 | + chrono::Local::now().format("%H:%M:%S.%6f"), |
| 87 | + record.args() |
| 88 | + ) |
| 89 | + }) |
| 90 | + .filter(None, cli.level.into()) |
| 91 | + .init(); |
| 92 | + } |
| 93 | + |
| 94 | + println!( |
| 95 | + "listening {}:{}(signal)/[{}-{}](media)...", |
| 96 | + cli.host, cli.signal_port, cli.media_port_min, cli.media_port_max |
| 97 | + ); |
| 98 | + |
| 99 | + let media_ports: Vec<u16> = (cli.media_port_min..=cli.media_port_max).collect(); |
| 100 | + let (stop_tx, mut stop_rx) = broadcast::<()>(1); |
| 101 | + let mut media_port_thread_map = HashMap::new(); |
| 102 | + |
| 103 | + let key_pair = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; |
| 104 | + let certificates = vec![RTCCertificate::from_key_pair(key_pair)?]; |
| 105 | + let dtls_handshake_config = Arc::new( |
| 106 | + dtls::config::ConfigBuilder::default() |
| 107 | + .with_certificates( |
| 108 | + certificates |
| 109 | + .iter() |
| 110 | + .map(|c| c.dtls_certificate.clone()) |
| 111 | + .collect(), |
| 112 | + ) |
| 113 | + .with_srtp_protection_profiles(vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80]) |
| 114 | + .with_extended_master_secret(dtls::config::ExtendedMasterSecretType::Require) |
| 115 | + .build(false, None)?, |
| 116 | + ); |
| 117 | + let sctp_endpoint_config = Arc::new(sctp::EndpointConfig::default()); |
| 118 | + let sctp_server_config = Arc::new(sctp::ServerConfig::default()); |
| 119 | + let server_config = Arc::new( |
| 120 | + ServerConfig::new(certificates) |
| 121 | + .with_dtls_handshake_config(dtls_handshake_config) |
| 122 | + .with_sctp_endpoint_config(sctp_endpoint_config) |
| 123 | + .with_sctp_server_config(sctp_server_config), |
| 124 | + ); |
| 125 | + let wait_group = WaitGroup::new(); |
| 126 | + let core_num = num_cpus::get(); |
| 127 | + |
| 128 | + for port in media_ports { |
| 129 | + let worker = wait_group.worker(); |
| 130 | + let host = cli.host.clone(); |
| 131 | + let mut stop_rx = stop_rx.clone(); |
| 132 | + let (signaling_tx, signaling_rx) = smol::channel::unbounded::<SignalingMessage>(); |
| 133 | + media_port_thread_map.insert(port, signaling_tx); |
| 134 | + |
| 135 | + let server_config = server_config.clone(); |
| 136 | + LocalExecutorBuilder::new() |
| 137 | + .name(format!("media_port_{}", port).as_str()) |
| 138 | + .core_id(core_affinity::CoreId { |
| 139 | + id: (port as usize) % core_num, |
| 140 | + }) |
| 141 | + .spawn(move || async move { |
| 142 | + let _worker = worker; |
| 143 | + let local_addr = SocketAddr::from_str(&format!("{}:{}", host, port)).unwrap(); |
| 144 | + let server_states = Rc::new(RefCell::new(ServerStates::new(server_config, local_addr).unwrap())); |
| 145 | + |
| 146 | + info!("listening {}:{}...", host, port); |
| 147 | + |
| 148 | + let server_states_moved = server_states.clone(); |
| 149 | + let mut bootstrap = BootstrapUdpServer::new(); |
| 150 | + bootstrap.pipeline(Box::new( |
| 151 | + move |writer: AsyncTransportWrite<TaggedBytesMut>| { |
| 152 | + let pipeline: Pipeline<TaggedBytesMut, TaggedBytesMut> = Pipeline::new(); |
| 153 | + |
| 154 | + let local_addr = writer.get_local_addr(); |
| 155 | + let async_transport_handler = AsyncTransport::new(writer); |
| 156 | + let demuxer_handler = DemuxerHandler::new(); |
| 157 | + let write_exception_handler = ExceptionHandler::new(); |
| 158 | + let stun_handler = StunHandler::new(); |
| 159 | + // DTLS |
| 160 | + let dtls_handler = DtlsHandler::new(local_addr, Rc::clone(&server_states_moved)); |
| 161 | + let sctp_handler = SctpHandler::new(local_addr, Rc::clone(&server_states_moved)); |
| 162 | + let data_channel_handler = DataChannelHandler::new(); |
| 163 | + // SRTP |
| 164 | + let srtp_handler = SrtpHandler::new(Rc::clone(&server_states_moved)); |
| 165 | + let interceptor_handler = InterceptorHandler::new(Rc::clone(&server_states_moved)); |
| 166 | + // Gateway |
| 167 | + let gateway_handler = GatewayHandler::new(Rc::clone(&server_states_moved)); |
| 168 | + let read_exception_handler = ExceptionHandler::new(); |
| 169 | + |
| 170 | + pipeline.add_back(async_transport_handler); |
| 171 | + pipeline.add_back(demuxer_handler); |
| 172 | + pipeline.add_back(write_exception_handler); |
| 173 | + pipeline.add_back(stun_handler); |
| 174 | + // DTLS |
| 175 | + pipeline.add_back(dtls_handler); |
| 176 | + pipeline.add_back(sctp_handler); |
| 177 | + pipeline.add_back(data_channel_handler); |
| 178 | + // SRTP |
| 179 | + pipeline.add_back(srtp_handler); |
| 180 | + pipeline.add_back(interceptor_handler); |
| 181 | + // Gateway |
| 182 | + pipeline.add_back(gateway_handler); |
| 183 | + pipeline.add_back(read_exception_handler); |
| 184 | + |
| 185 | + pipeline.finalize() |
| 186 | + }, |
| 187 | + )); |
| 188 | + |
| 189 | + if let Err(err) = bootstrap.bind(format!("{}:{}", host, port)).await { |
| 190 | + error!("bootstrap binding error: {}", err); |
| 191 | + return; |
| 192 | + } |
| 193 | + |
| 194 | + loop { |
| 195 | + tokio::select! { |
| 196 | + _ = stop_rx.recv() => { |
| 197 | + info!("media server on {}:{} receives stop signal", host, port); |
| 198 | + break; |
| 199 | + } |
| 200 | + recv = signaling_rx.recv() => { |
| 201 | + match recv { |
| 202 | + Ok(signaling_msg) => { |
| 203 | + if let Err(err) = handle_signaling_message(&server_states, signaling_msg) { |
| 204 | + error!("handle_signaling_message error: {}", err); |
| 205 | + } |
| 206 | + } |
| 207 | + Err(err) => { |
| 208 | + error!("signal_rx recv error: {}", err); |
| 209 | + break; |
| 210 | + } |
| 211 | + } |
| 212 | + } |
| 213 | + } |
| 214 | + } |
| 215 | + |
| 216 | + bootstrap.graceful_stop().await; |
| 217 | + info!("media server on {}:{} is gracefully down", host, port); |
| 218 | + })?; |
| 219 | + } |
| 220 | + |
| 221 | + let signaling_addr = SocketAddr::from_str(&format!("{}:{}", cli.host, cli.signal_port))?; |
| 222 | + let signaling_stop_rx = stop_rx.clone(); |
| 223 | + let signaling_handle = std::thread::spawn(move || { |
| 224 | + let rt = tokio::runtime::Builder::new_current_thread() |
| 225 | + .enable_io() |
| 226 | + .enable_time() |
| 227 | + .build() |
| 228 | + .unwrap(); |
| 229 | + |
| 230 | + rt.block_on(async { |
| 231 | + let signaling_server = SignalingServer::new(signaling_addr, media_port_thread_map); |
| 232 | + let mut done_rx = signaling_server.run(signaling_stop_rx).await; |
| 233 | + let _ = done_rx.recv().await; |
| 234 | + wait_group.wait().await; |
| 235 | + info!("signaling server is gracefully down"); |
| 236 | + }) |
| 237 | + }); |
| 238 | + |
| 239 | + LocalExecutorBuilder::default().run(async move { |
| 240 | + println!("Press Ctrl-C to stop"); |
| 241 | + std::thread::spawn(move || { |
| 242 | + let mut stop_tx = Some(stop_tx); |
| 243 | + ctrlc::set_handler(move || { |
| 244 | + if let Some(stop_tx) = stop_tx.take() { |
| 245 | + let _ = stop_tx.try_broadcast(()); |
| 246 | + } |
| 247 | + }) |
| 248 | + .expect("Error setting Ctrl-C handler"); |
| 249 | + }); |
| 250 | + let _ = stop_rx.recv().await; |
| 251 | + println!("Wait for Signaling Sever and Media Server Gracefully Shutdown..."); |
| 252 | + }); |
| 253 | + |
| 254 | + let _ = signaling_handle.join(); |
| 255 | + |
| 256 | + Ok(()) |
| 257 | +} |
0 commit comments