|  | 
| 5 | 5 | 
 | 
| 6 | 6 | use std::collections::HashMap; | 
| 7 | 7 | use std::sync::atomic::{AtomicUsize, Ordering}; | 
|  | 8 | +use std::sync::Arc; | 
| 8 | 9 | 
 | 
| 9 |  | -use futures_util::{FutureExt, StreamExt}; | 
| 10 | 10 | use once_cell::sync::Lazy; | 
| 11 | 11 | use tokio::sync::{mpsc, RwLock}; | 
| 12 |  | -use tokio_stream::wrappers::UnboundedReceiverStream; | 
| 13 | 12 | 
 | 
| 14 | 13 | use silent::prelude::*; | 
| 15 | 14 | 
 | 
| 16 |  | -type Users = RwLock<HashMap<usize, mpsc::UnboundedSender<Result<Message>>>>; | 
|  | 15 | +type Users = RwLock<HashMap<usize, mpsc::UnboundedSender<Message>>>; | 
| 17 | 16 | 
 | 
| 18 | 17 | static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1); | 
| 19 | 18 | static ONLINE_USERS: Lazy<Users> = Lazy::new(Users::default); | 
| 20 | 19 | 
 | 
| 21 | 20 | fn main() { | 
| 22 | 21 |     logger::fmt().init(); | 
| 23 |  | -    let route = Route::new("") | 
| 24 |  | -        .get(index) | 
| 25 |  | -        .append(Route::new("chat").ws(None, handle_socket)); | 
|  | 22 | +    let route = Route::new("").get(index).append( | 
|  | 23 | +        Route::new("chat").ws( | 
|  | 24 | +            None, | 
|  | 25 | +            WebSocketHandler::new() | 
|  | 26 | +                .on_connect(on_connect) | 
|  | 27 | +                .on_send(on_send) | 
|  | 28 | +                .on_receive(on_receive) | 
|  | 29 | +                .on_close(on_close), | 
|  | 30 | +        ), | 
|  | 31 | +    ); | 
| 26 | 32 |     Server::new().bind_route(route).run(); | 
| 27 | 33 | } | 
| 28 | 34 | 
 | 
| 29 |  | -async fn handle_socket(ws: WebSocket) { | 
| 30 |  | -    // Use a counter to assign a new unique ID for this user. | 
|  | 35 | +async fn on_connect( | 
|  | 36 | +    parts: Arc<RwLock<WebSocketParts>>, | 
|  | 37 | +    sender: mpsc::UnboundedSender<Message>, | 
|  | 38 | +) -> Result<()> { | 
|  | 39 | +    let mut parts = parts.write().await; | 
|  | 40 | +    info!("{:?}", parts); | 
| 31 | 41 |     let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed); | 
| 32 |  | - | 
| 33 | 42 |     info!("new chat user: {}", my_id); | 
|  | 43 | +    parts.extensions_mut().insert(my_id); | 
|  | 44 | +    sender | 
|  | 45 | +        .send(Message::text(format!("Hello User#{my_id}"))) | 
|  | 46 | +        .unwrap(); | 
|  | 47 | +    ONLINE_USERS.write().await.insert(my_id, sender); | 
|  | 48 | +    Ok(()) | 
|  | 49 | +} | 
| 34 | 50 | 
 | 
| 35 |  | -    // Split the socket into a sender and receive of messages. | 
| 36 |  | -    let (user_ws_tx, mut user_ws_rx) = ws.split(); | 
| 37 |  | - | 
| 38 |  | -    // Use an unbounded channel to handle buffering and flushing of messages | 
| 39 |  | -    // to the websocket... | 
| 40 |  | -    let (tx, rx) = mpsc::unbounded_channel(); | 
| 41 |  | -    let rx = UnboundedReceiverStream::new(rx); | 
| 42 |  | -    let fut = rx.forward(user_ws_tx).map(|result| { | 
| 43 |  | -        if let Err(e) = result { | 
| 44 |  | -            error!(error = ?e, "websocket send error"); | 
| 45 |  | -        } | 
| 46 |  | -    }); | 
| 47 |  | -    tokio::task::spawn(fut); | 
| 48 |  | -    let fut = async move { | 
| 49 |  | -        ONLINE_USERS.write().await.insert(my_id, tx); | 
| 50 |  | - | 
| 51 |  | -        while let Some(result) = user_ws_rx.next().await { | 
| 52 |  | -            let msg = match result { | 
| 53 |  | -                Ok(msg) => msg, | 
| 54 |  | -                Err(e) => { | 
| 55 |  | -                    eprintln!("websocket error(uid={my_id}): {e}"); | 
| 56 |  | -                    break; | 
| 57 |  | -                } | 
| 58 |  | -            }; | 
| 59 |  | -            user_message(my_id, msg).await; | 
| 60 |  | -        } | 
| 61 |  | - | 
| 62 |  | -        user_disconnected(my_id).await; | 
| 63 |  | -    }; | 
| 64 |  | -    tokio::task::spawn(fut); | 
|  | 51 | +async fn on_send(message: Message, _parts: Arc<RwLock<WebSocketParts>>) -> Result<Message> { | 
|  | 52 | +    info!("on_send: {:?}", message); | 
|  | 53 | +    Ok(message) | 
| 65 | 54 | } | 
| 66 | 55 | 
 | 
| 67 |  | -async fn user_message(my_id: usize, msg: Message) { | 
| 68 |  | -    let msg = if let Ok(s) = msg.to_str() { | 
|  | 56 | +async fn on_receive(message: Message, parts: Arc<RwLock<WebSocketParts>>) -> Result<()> { | 
|  | 57 | +    let parts = parts.read().await; | 
|  | 58 | +    let my_id = parts.extensions().get::<usize>().unwrap(); | 
|  | 59 | +    info!("on_receive: {:?}", message); | 
|  | 60 | +    let msg = if let Ok(s) = message.to_str() { | 
| 69 | 61 |         s | 
| 70 | 62 |     } else { | 
| 71 |  | -        return; | 
|  | 63 | +        return Err(SilentError::BusinessError { | 
|  | 64 | +            code: StatusCode::BAD_REQUEST, | 
|  | 65 | +            msg: "invalid message".to_string(), | 
|  | 66 | +        }); | 
| 72 | 67 |     }; | 
| 73 |  | - | 
| 74 |  | -    let new_msg = format!("<User#{my_id}>: {msg}"); | 
| 75 |  | - | 
| 76 |  | -    // New message from this user, send it to everyone else (except same uid)... | 
| 77 |  | -    for (&uid, tx) in ONLINE_USERS.read().await.iter() { | 
|  | 68 | +    let message = Message::text(format!("<User#{my_id}>: {msg}")); | 
|  | 69 | +    for (uid, tx) in ONLINE_USERS.read().await.iter() { | 
| 78 | 70 |         if my_id != uid { | 
| 79 |  | -            if let Err(_disconnected) = tx.send(Ok(Message::text(new_msg.clone()))) { | 
| 80 |  | -                // The tx is disconnected, our `user_disconnected` code | 
| 81 |  | -                // should be happening in another task, nothing more to | 
| 82 |  | -                // do here. | 
| 83 |  | -            } | 
|  | 71 | +            if let Err(_disconnected) = tx.send(message.clone()) {} | 
| 84 | 72 |         } | 
| 85 | 73 |     } | 
|  | 74 | +    Ok(()) | 
| 86 | 75 | } | 
| 87 | 76 | 
 | 
| 88 |  | -async fn user_disconnected(my_id: usize) { | 
| 89 |  | -    eprintln!("good bye user: {my_id}"); | 
|  | 77 | +async fn on_close(parts: Arc<RwLock<WebSocketParts>>) { | 
|  | 78 | +    let parts = parts.read().await; | 
|  | 79 | +    let my_id = parts.extensions().get::<usize>().unwrap(); | 
|  | 80 | +    info!("good bye user: {my_id}"); | 
| 90 | 81 |     // Stream closed up, so remove from the user list | 
| 91 |  | -    ONLINE_USERS.write().await.remove(&my_id); | 
|  | 82 | +    ONLINE_USERS.write().await.remove(my_id); | 
| 92 | 83 | } | 
| 93 | 84 | 
 | 
| 94 | 85 | async fn index<'a>(_res: Request) -> Result<&'a str> { | 
|  | 
0 commit comments