@@ -2,7 +2,7 @@ use crate::{
2
2
metrics:: { DURATION_WS , LIVE_WS } ,
3
3
parse_channel, parse_crate_type, parse_edition, parse_mode,
4
4
sandbox:: { self , Sandbox } ,
5
- Error , ExecutionSnafu , Result , SandboxCreationSnafu ,
5
+ Error , ExecutionSnafu , Result , SandboxCreationSnafu , WebSocketTaskPanicSnafu ,
6
6
} ;
7
7
8
8
use axum:: extract:: ws:: { Message , WebSocket } ;
@@ -11,7 +11,7 @@ use std::{
11
11
convert:: { TryFrom , TryInto } ,
12
12
time:: Instant ,
13
13
} ;
14
- use tokio:: sync:: mpsc;
14
+ use tokio:: { sync:: mpsc, task :: JoinSet } ;
15
15
16
16
#[ derive( serde:: Deserialize ) ]
17
17
#[ serde( tag = "type" ) ]
@@ -109,6 +109,9 @@ pub async fn handle(mut socket: WebSocket) {
109
109
let start = Instant :: now ( ) ;
110
110
111
111
let ( tx, mut rx) = mpsc:: channel ( 3 ) ;
112
+ let mut tasks = JoinSet :: new ( ) ;
113
+
114
+ // TODO: Implement some kind of timeout to shutdown running work?
112
115
113
116
loop {
114
117
tokio:: select! {
@@ -118,7 +121,7 @@ pub async fn handle(mut socket: WebSocket) {
118
121
// browser disconnected
119
122
break ;
120
123
}
121
- Some ( Ok ( Message :: Text ( txt) ) ) => handle_msg( txt, & tx) . await ,
124
+ Some ( Ok ( Message :: Text ( txt) ) ) => handle_msg( txt, & tx, & mut tasks ) . await ,
122
125
Some ( Ok ( _) ) => {
123
126
// unknown message type
124
127
continue ;
@@ -128,10 +131,31 @@ pub async fn handle(mut socket: WebSocket) {
128
131
} ,
129
132
resp = rx. recv( ) => {
130
133
let resp = resp. expect( "The rx should never close as we have a tx" ) ;
131
- let resp = resp. unwrap_or_else( |e| WSMessageResponse :: Error ( WSError { error: e. to_string( ) } ) ) ;
132
- const LAST_CHANCE_ERROR : & str = r#"{ "type": "WEBSOCKET_ERROR", "error": "Unable to serialize JSON" }"# ;
133
- let resp = serde_json:: to_string( & resp) . unwrap_or_else( |_| LAST_CHANCE_ERROR . into( ) ) ;
134
- let resp = Message :: Text ( resp) ;
134
+ let resp = resp. unwrap_or_else( error_to_response) ;
135
+ let resp = response_to_message( resp) ;
136
+
137
+ if let Err ( _) = socket. send( resp) . await {
138
+ // We can't send a response
139
+ break ;
140
+ }
141
+ } ,
142
+ // We don't care if there are no running tasks
143
+ Some ( task) = tasks. join_next( ) => {
144
+ let Err ( error) = task else { continue } ;
145
+ // The task was cancelled; no need to report
146
+ let Ok ( panic) = error. try_into_panic( ) else { continue } ;
147
+
148
+ let text = match panic. downcast:: <String >( ) {
149
+ Ok ( text) => * text,
150
+ Err ( panic) => match panic. downcast:: <& str >( ) {
151
+ Ok ( text) => text. to_string( ) ,
152
+ _ => "An unknown panic occurred" . into( ) ,
153
+ }
154
+ } ;
155
+ let error = WebSocketTaskPanicSnafu { text } . build( ) ;
156
+
157
+ let resp = error_to_response( error) ;
158
+ let resp = response_to_message( resp) ;
135
159
136
160
if let Err ( _) = socket. send( resp) . await {
137
161
// We can't send a response
@@ -141,22 +165,42 @@ pub async fn handle(mut socket: WebSocket) {
141
165
}
142
166
}
143
167
168
+ drop ( ( tx, rx, socket) ) ;
169
+ tasks. shutdown ( ) . await ;
170
+
144
171
LIVE_WS . dec ( ) ;
145
172
let elapsed = start. elapsed ( ) ;
146
173
DURATION_WS . observe ( elapsed. as_secs_f64 ( ) ) ;
147
174
}
148
175
149
- async fn handle_msg ( txt : String , tx : & mpsc:: Sender < Result < WSMessageResponse > > ) {
176
+ fn error_to_response ( error : Error ) -> WSMessageResponse {
177
+ let error = error. to_string ( ) ;
178
+ WSMessageResponse :: Error ( WSError { error } )
179
+ }
180
+
181
+ fn response_to_message ( response : WSMessageResponse ) -> Message {
182
+ const LAST_CHANCE_ERROR : & str =
183
+ r#"{ "type": "WEBSOCKET_ERROR", "error": "Unable to serialize JSON" }"# ;
184
+ let resp = serde_json:: to_string ( & response) . unwrap_or_else ( |_| LAST_CHANCE_ERROR . into ( ) ) ;
185
+ Message :: Text ( resp)
186
+ }
187
+
188
+ async fn handle_msg (
189
+ txt : String ,
190
+ tx : & mpsc:: Sender < Result < WSMessageResponse > > ,
191
+ tasks : & mut JoinSet < Result < ( ) > > ,
192
+ ) {
150
193
use WSMessageRequest :: * ;
151
194
152
195
let msg = serde_json:: from_str ( & txt) . context ( crate :: DeserializationSnafu ) ;
153
196
154
197
match msg {
155
198
Ok ( WSExecuteRequest ( req) ) => {
156
199
let tx = tx. clone ( ) ;
157
- tokio :: spawn ( async move {
200
+ tasks . spawn ( async move {
158
201
let resp = handle_execute ( req) . await ;
159
202
tx. send ( resp) . await . ok ( /* We don't care if the channel is closed */ ) ;
203
+ Ok ( ( ) )
160
204
} ) ;
161
205
}
162
206
Err ( e) => {
0 commit comments