@@ -17,6 +17,7 @@ use std::{
17
17
boxed,
18
18
fmt:: { self , Debug , Display , Formatter } ,
19
19
future:: Future ,
20
+ io:: ErrorKind ,
20
21
mem,
21
22
pin:: Pin ,
22
23
str:: FromStr ,
@@ -59,6 +60,10 @@ pub trait Client: Send + Sync + private::Sealed {
59
60
* TODO specify list of stati to not retry (e.g. 204)
60
61
*/
61
62
63
+ /// Maximum amount of redirects that the client will follow before
64
+ /// giving up, if not overridden via [ClientBuilder::redirect_limit].
65
+ pub const DEFAULT_REDIRECT_LIMIT : u32 = 16 ;
66
+
62
67
/// ClientBuilder provides a series of builder methods to easily construct a [`Client`].
63
68
pub struct ClientBuilder {
64
69
url : Uri ,
@@ -68,6 +73,7 @@ pub struct ClientBuilder {
68
73
last_event_id : Option < String > ,
69
74
method : String ,
70
75
body : Option < String > ,
76
+ max_redirects : Option < u32 > ,
71
77
}
72
78
73
79
impl ClientBuilder {
@@ -88,6 +94,7 @@ impl ClientBuilder {
88
94
read_timeout : None ,
89
95
last_event_id : None ,
90
96
method : String :: from ( "GET" ) ,
97
+ max_redirects : None ,
91
98
body : None ,
92
99
} )
93
100
}
@@ -137,6 +144,14 @@ impl ClientBuilder {
137
144
self
138
145
}
139
146
147
+ /// Customize the client's following behavior when served a redirect.
148
+ /// To disable following redirects, pass `0`.
149
+ /// By default, the limit is [`DEFAULT_REDIRECT_LIMIT`].
150
+ pub fn redirect_limit ( mut self , limit : u32 ) -> ClientBuilder {
151
+ self . max_redirects = Some ( limit) ;
152
+ self
153
+ }
154
+
140
155
/// Build with a specific client connector.
141
156
pub fn build_with_conn < C > ( self , conn : C ) -> impl Client
142
157
where
@@ -158,6 +173,7 @@ impl ClientBuilder {
158
173
method : self . method ,
159
174
body : self . body ,
160
175
reconnect_opts : self . reconnect_opts ,
176
+ max_redirects : self . max_redirects . unwrap_or ( DEFAULT_REDIRECT_LIMIT ) ,
161
177
} ,
162
178
last_event_id : self . last_event_id ,
163
179
}
@@ -188,6 +204,7 @@ impl ClientBuilder {
188
204
method : self . method ,
189
205
body : self . body ,
190
206
reconnect_opts : self . reconnect_opts ,
207
+ max_redirects : self . max_redirects . unwrap_or ( DEFAULT_REDIRECT_LIMIT ) ,
191
208
} ,
192
209
last_event_id : self . last_event_id ,
193
210
}
@@ -201,6 +218,7 @@ struct RequestProps {
201
218
method : String ,
202
219
body : Option < String > ,
203
220
reconnect_opts : ReconnectOptions ,
221
+ max_redirects : u32 ,
204
222
}
205
223
206
224
/// A client implementation that connects to a server using the Server-Sent Events protocol
@@ -243,6 +261,7 @@ enum State {
243
261
} ,
244
262
Connected ( #[ pin] hyper:: Body ) ,
245
263
WaitingToReconnect ( #[ pin] Sleep ) ,
264
+ FollowingRedirect ( Option < HeaderValue > ) ,
246
265
StreamClosed ,
247
266
}
248
267
@@ -254,6 +273,7 @@ impl State {
254
273
State :: Connecting { retry : true , .. } => "connecting(retry)" ,
255
274
State :: Connected ( _) => "connected" ,
256
275
State :: WaitingToReconnect ( _) => "waiting-to-reconnect" ,
276
+ State :: FollowingRedirect ( _) => "following-redirect" ,
257
277
State :: StreamClosed => "closed" ,
258
278
}
259
279
}
@@ -273,6 +293,8 @@ pub struct ReconnectingRequest<C> {
273
293
#[ pin]
274
294
state : State ,
275
295
next_reconnect_delay : Duration ,
296
+ current_url : Uri ,
297
+ redirect_count : u32 ,
276
298
event_parser : EventParser ,
277
299
last_event_id : Option < String > ,
278
300
}
@@ -284,11 +306,14 @@ impl<C> ReconnectingRequest<C> {
284
306
last_event_id : Option < String > ,
285
307
) -> ReconnectingRequest < C > {
286
308
let reconnect_delay = props. reconnect_opts . delay ;
309
+ let url = props. url . clone ( ) ;
287
310
ReconnectingRequest {
288
311
props,
289
312
http,
290
313
state : State :: New ,
291
314
next_reconnect_delay : reconnect_delay,
315
+ redirect_count : 0 ,
316
+ current_url : url,
292
317
event_parser : EventParser :: new ( ) ,
293
318
last_event_id,
294
319
}
@@ -300,7 +325,7 @@ impl<C> ReconnectingRequest<C> {
300
325
{
301
326
let mut request_builder = Request :: builder ( )
302
327
. method ( self . props . method . as_str ( ) )
303
- . uri ( & self . props . url ) ;
328
+ . uri ( & self . current_url ) ;
304
329
305
330
for ( name, value) in & self . props . headers {
306
331
request_builder = request_builder. header ( name, value) ;
@@ -343,6 +368,21 @@ impl<C> ReconnectingRequest<C> {
343
368
let this = self . project ( ) ;
344
369
mem:: swap ( this. next_reconnect_delay , & mut delay) ;
345
370
}
371
+
372
+ fn reset_redirects ( self : Pin < & mut Self > ) {
373
+ let url = self . props . url . clone ( ) ;
374
+ let this = self . project ( ) ;
375
+ * this. current_url = url;
376
+ * this. redirect_count = 0 ;
377
+ }
378
+
379
+ fn increment_redirect_counter ( self : Pin < & mut Self > ) -> bool {
380
+ if self . redirect_count == self . props . max_redirects {
381
+ return false ;
382
+ }
383
+ * self . project ( ) . redirect_count += 1 ;
384
+ true
385
+ }
346
386
}
347
387
348
388
impl < C > Stream for ReconnectingRequest < C >
@@ -400,16 +440,39 @@ where
400
440
Ok ( resp) => {
401
441
debug ! ( "HTTP response: {:#?}" , resp) ;
402
442
403
- if !resp. status ( ) . is_success ( ) {
404
- self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
405
- return Poll :: Ready ( Some ( Err ( Error :: HttpRequest ( resp. status ( ) ) ) ) ) ;
443
+ if resp. status ( ) . is_success ( ) {
444
+ self . as_mut ( ) . reset_backoff ( ) ;
445
+ self . as_mut ( ) . reset_redirects ( ) ;
446
+ self . as_mut ( )
447
+ . project ( )
448
+ . state
449
+ . set ( State :: Connected ( resp. into_body ( ) ) ) ;
450
+ continue ;
406
451
}
407
452
408
- self . as_mut ( ) . reset_backoff ( ) ;
409
- self . as_mut ( )
410
- . project ( )
411
- . state
412
- . set ( State :: Connected ( resp. into_body ( ) ) ) ;
453
+ if resp. status ( ) == 301 || resp. status ( ) == 307 {
454
+ debug ! ( "got redirected ({})" , resp. status( ) ) ;
455
+
456
+ if self . as_mut ( ) . increment_redirect_counter ( ) {
457
+ debug ! ( "following redirect {}" , self . redirect_count) ;
458
+
459
+ self . as_mut ( ) . project ( ) . state . set ( State :: FollowingRedirect (
460
+ resp. headers ( ) . get ( hyper:: header:: LOCATION ) . cloned ( ) ,
461
+ ) ) ;
462
+ continue ;
463
+ } else {
464
+ debug ! ( "redirect limit reached ({})" , self . props. max_redirects) ;
465
+
466
+ self . as_mut ( ) . project ( ) . state . set ( State :: StreamClosed ) ;
467
+ return Poll :: Ready ( Some ( Err ( Error :: MaxRedirectLimitReached (
468
+ self . props . max_redirects ,
469
+ ) ) ) ) ;
470
+ }
471
+ }
472
+
473
+ self . as_mut ( ) . reset_redirects ( ) ;
474
+ self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
475
+ return Poll :: Ready ( Some ( Err ( Error :: UnexpectedResponse ( resp. status ( ) ) ) ) ) ;
413
476
}
414
477
Err ( e) => {
415
478
// This seems basically impossible. AFAIK we can only get this way if we
@@ -426,6 +489,16 @@ where
426
489
. set ( State :: WaitingToReconnect ( delay ( duration, "retrying" ) ) )
427
490
}
428
491
} ,
492
+ StateProj :: FollowingRedirect ( maybe_header) => match uri_from_header ( maybe_header) {
493
+ Ok ( uri) => {
494
+ * self . as_mut ( ) . project ( ) . current_url = uri;
495
+ self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
496
+ }
497
+ Err ( e) => {
498
+ self . as_mut ( ) . project ( ) . state . set ( State :: StreamClosed ) ;
499
+ return Poll :: Ready ( Some ( Err ( e) ) ) ;
500
+ }
501
+ } ,
429
502
StateProj :: Connected ( body) => match ready ! ( body. poll_data( cx) ) {
430
503
Some ( Ok ( result) ) => {
431
504
this. event_parser . process_bytes ( result) ?;
@@ -473,6 +546,23 @@ where
473
546
}
474
547
}
475
548
549
+ fn uri_from_header ( maybe_header : & Option < HeaderValue > ) -> Result < Uri > {
550
+ let header = maybe_header. as_ref ( ) . ok_or_else ( || {
551
+ Error :: MalformedLocationHeader ( Box :: new ( std:: io:: Error :: new (
552
+ ErrorKind :: NotFound ,
553
+ "missing Location header" ,
554
+ ) ) )
555
+ } ) ?;
556
+
557
+ let header_string = header
558
+ . to_str ( )
559
+ . map_err ( |e| Error :: MalformedLocationHeader ( Box :: new ( e) ) ) ?;
560
+
561
+ header_string
562
+ . parse :: < Uri > ( )
563
+ . map_err ( |e| Error :: MalformedLocationHeader ( Box :: new ( e) ) )
564
+ }
565
+
476
566
fn delay ( dur : Duration , description : & str ) -> Sleep {
477
567
info ! ( "Waiting {:?} before {}" , dur, description) ;
478
568
tokio:: time:: sleep ( dur)
0 commit comments