1
1
use libc:: { ECONNABORTED , EMFILE , ENFILE , ENOBUFS , ENOMEM } ;
2
2
use timestamped_socket:: interface:: ChangeDetector ;
3
- use tokio:: io:: AsyncWriteExt ;
3
+ use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
4
4
use tokio:: net:: TcpListener ;
5
5
use tracing:: { debug, error, trace, warn} ;
6
6
@@ -211,9 +211,43 @@ async fn run(options: NtpMetricsExporterOptions) -> Result<(), Box<dyn std::erro
211
211
}
212
212
213
213
async fn handle_connection (
214
- stream : & mut ( impl tokio:: io:: AsyncWrite + Unpin ) ,
214
+ stream : & mut ( impl tokio:: io:: AsyncWrite + tokio :: io :: AsyncRead + Unpin ) ,
215
215
observation_socket_path : & Path ,
216
216
) -> std:: io:: Result < ( ) > {
217
+ // Wait until a request was sent, dropping the bytes read when this scope ends
218
+ // to ensure we don't accidentally use them afterwards
219
+ {
220
+ // Receive all data until the header was fully received, or until max buf size
221
+ let mut buf = [ 0u8 ; 2048 ] ;
222
+ let mut bytes_read = 0 ;
223
+ loop {
224
+ bytes_read += stream. read ( & mut buf[ bytes_read..] ) . await ?;
225
+
226
+ // The headers end with two CRLFs in a row
227
+ if buf[ 0 ..bytes_read] . windows ( 4 ) . any ( |w| w == b"\r \n \r \n " ) {
228
+ break ;
229
+ }
230
+
231
+ // Headers should easily fit within the buffer
232
+ // If we have not found the end yet, we are not going to
233
+ if bytes_read >= buf. len ( ) {
234
+ return Err ( std:: io:: Error :: new (
235
+ std:: io:: ErrorKind :: InvalidInput ,
236
+ "Request too long" ,
237
+ ) ) ;
238
+ }
239
+ }
240
+
241
+ // We only respond to GET requests
242
+ if !buf[ 0 ..bytes_read] . starts_with ( b"GET " ) {
243
+ return Err ( std:: io:: Error :: new (
244
+ std:: io:: ErrorKind :: InvalidInput ,
245
+ "Expected GET request" ,
246
+ ) ) ;
247
+ }
248
+ }
249
+
250
+ // Send the response
217
251
let mut buf = String :: with_capacity ( 4 * 1024 ) ;
218
252
match handler ( & mut buf, observation_socket_path) . await {
219
253
Ok ( ( ) ) => {
@@ -261,6 +295,8 @@ fn format_response(buf: &mut String, state: &ObservableState) -> std::fmt::Resul
261
295
262
296
#[ cfg( test) ]
263
297
mod tests {
298
+ use std:: io:: Cursor ;
299
+
264
300
use super :: * ;
265
301
266
302
const BINARY : & str = "/usr/bin/ntp-metrics-exporter" ;
@@ -274,4 +310,24 @@ mod tests {
274
310
let options = NtpMetricsExporterOptions :: try_parse_from ( arguments) . unwrap ( ) ;
275
311
assert_eq ! ( options. config. unwrap( ) . as_path( ) , config) ;
276
312
}
313
+
314
+ #[ tokio:: test]
315
+ async fn deny_non_get_request ( ) {
316
+ let mut example = b"POST / HTTP/1.1\r \n \r \n " . to_vec ( ) ;
317
+ let mut cursor = Cursor :: new ( & mut example) ;
318
+ let res = handle_connection ( & mut cursor, Path :: new ( "/tmp/ntpd-rs.sock" ) ) . await ;
319
+ let err = res. unwrap_err ( ) ;
320
+ assert_eq ! ( err. kind( ) , std:: io:: ErrorKind :: InvalidInput ) ;
321
+ assert_eq ! ( err. to_string( ) , "Expected GET request" ) ;
322
+ }
323
+
324
+ #[ tokio:: test]
325
+ async fn does_not_accept_large_requests ( ) {
326
+ let mut example = [ 1u8 ; 4096 ] . to_vec ( ) ;
327
+ let mut cursor = Cursor :: new ( & mut example) ;
328
+ let res = handle_connection ( & mut cursor, Path :: new ( "/tmp/ntpd-rs.sock" ) ) . await ;
329
+ let err = res. unwrap_err ( ) ;
330
+ assert_eq ! ( err. kind( ) , std:: io:: ErrorKind :: InvalidInput ) ;
331
+ assert_eq ! ( err. to_string( ) , "Request too long" ) ;
332
+ }
277
333
}
0 commit comments