Skip to content

Commit db753ce

Browse files
authored
Handle redirects with new FollowRedirects state (#29)
* Implement FollowRedirect state to handle 307 and 301 redirects * Add redirect_limit configuration option
1 parent 257788b commit db753ce

File tree

3 files changed

+105
-13
lines changed

3 files changed

+105
-13
lines changed

Makefile

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
TEMP_TEST_OUTPUT=/tmp/contract-test-service.log
2-
SKIPFLAGS = -skip 'HTTP behavior/client follows 301 redirect' -skip 'HTTP behavior/client follows 307 redirect'
3-
42

53
build-contract-tests:
64
@cargo build

eventsource-client/src/client.rs

Lines changed: 99 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use std::{
1717
boxed,
1818
fmt::{self, Debug, Display, Formatter},
1919
future::Future,
20+
io::ErrorKind,
2021
mem,
2122
pin::Pin,
2223
str::FromStr,
@@ -59,6 +60,10 @@ pub trait Client: Send + Sync + private::Sealed {
5960
* TODO specify list of stati to not retry (e.g. 204)
6061
*/
6162

63+
/// Maximum amount of redirects that the client will follow before
64+
/// giving up, if not overridden via [ClientBuilder::redirect_limit].
65+
pub const DEFAULT_REDIRECT_LIMIT: u32 = 16;
66+
6267
/// ClientBuilder provides a series of builder methods to easily construct a [`Client`].
6368
pub struct ClientBuilder {
6469
url: Uri,
@@ -68,6 +73,7 @@ pub struct ClientBuilder {
6873
last_event_id: Option<String>,
6974
method: String,
7075
body: Option<String>,
76+
max_redirects: Option<u32>,
7177
}
7278

7379
impl ClientBuilder {
@@ -88,6 +94,7 @@ impl ClientBuilder {
8894
read_timeout: None,
8995
last_event_id: None,
9096
method: String::from("GET"),
97+
max_redirects: None,
9198
body: None,
9299
})
93100
}
@@ -137,6 +144,14 @@ impl ClientBuilder {
137144
self
138145
}
139146

147+
/// Customize the client's following behavior when served a redirect.
148+
/// To disable following redirects, pass `0`.
149+
/// By default, the limit is [`DEFAULT_REDIRECT_LIMIT`].
150+
pub fn redirect_limit(mut self, limit: u32) -> ClientBuilder {
151+
self.max_redirects = Some(limit);
152+
self
153+
}
154+
140155
/// Build with a specific client connector.
141156
pub fn build_with_conn<C>(self, conn: C) -> impl Client
142157
where
@@ -158,6 +173,7 @@ impl ClientBuilder {
158173
method: self.method,
159174
body: self.body,
160175
reconnect_opts: self.reconnect_opts,
176+
max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
161177
},
162178
last_event_id: self.last_event_id,
163179
}
@@ -188,6 +204,7 @@ impl ClientBuilder {
188204
method: self.method,
189205
body: self.body,
190206
reconnect_opts: self.reconnect_opts,
207+
max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
191208
},
192209
last_event_id: self.last_event_id,
193210
}
@@ -201,6 +218,7 @@ struct RequestProps {
201218
method: String,
202219
body: Option<String>,
203220
reconnect_opts: ReconnectOptions,
221+
max_redirects: u32,
204222
}
205223

206224
/// A client implementation that connects to a server using the Server-Sent Events protocol
@@ -243,6 +261,7 @@ enum State {
243261
},
244262
Connected(#[pin] hyper::Body),
245263
WaitingToReconnect(#[pin] Sleep),
264+
FollowingRedirect(Option<HeaderValue>),
246265
StreamClosed,
247266
}
248267

@@ -254,6 +273,7 @@ impl State {
254273
State::Connecting { retry: true, .. } => "connecting(retry)",
255274
State::Connected(_) => "connected",
256275
State::WaitingToReconnect(_) => "waiting-to-reconnect",
276+
State::FollowingRedirect(_) => "following-redirect",
257277
State::StreamClosed => "closed",
258278
}
259279
}
@@ -273,6 +293,8 @@ pub struct ReconnectingRequest<C> {
273293
#[pin]
274294
state: State,
275295
next_reconnect_delay: Duration,
296+
current_url: Uri,
297+
redirect_count: u32,
276298
event_parser: EventParser,
277299
last_event_id: Option<String>,
278300
}
@@ -284,11 +306,14 @@ impl<C> ReconnectingRequest<C> {
284306
last_event_id: Option<String>,
285307
) -> ReconnectingRequest<C> {
286308
let reconnect_delay = props.reconnect_opts.delay;
309+
let url = props.url.clone();
287310
ReconnectingRequest {
288311
props,
289312
http,
290313
state: State::New,
291314
next_reconnect_delay: reconnect_delay,
315+
redirect_count: 0,
316+
current_url: url,
292317
event_parser: EventParser::new(),
293318
last_event_id,
294319
}
@@ -300,7 +325,7 @@ impl<C> ReconnectingRequest<C> {
300325
{
301326
let mut request_builder = Request::builder()
302327
.method(self.props.method.as_str())
303-
.uri(&self.props.url);
328+
.uri(&self.current_url);
304329

305330
for (name, value) in &self.props.headers {
306331
request_builder = request_builder.header(name, value);
@@ -343,6 +368,21 @@ impl<C> ReconnectingRequest<C> {
343368
let this = self.project();
344369
mem::swap(this.next_reconnect_delay, &mut delay);
345370
}
371+
372+
fn reset_redirects(self: Pin<&mut Self>) {
373+
let url = self.props.url.clone();
374+
let this = self.project();
375+
*this.current_url = url;
376+
*this.redirect_count = 0;
377+
}
378+
379+
fn increment_redirect_counter(self: Pin<&mut Self>) -> bool {
380+
if self.redirect_count == self.props.max_redirects {
381+
return false;
382+
}
383+
*self.project().redirect_count += 1;
384+
true
385+
}
346386
}
347387

348388
impl<C> Stream for ReconnectingRequest<C>
@@ -400,16 +440,39 @@ where
400440
Ok(resp) => {
401441
debug!("HTTP response: {:#?}", resp);
402442

403-
if !resp.status().is_success() {
404-
self.as_mut().project().state.set(State::New);
405-
return Poll::Ready(Some(Err(Error::HttpRequest(resp.status()))));
443+
if resp.status().is_success() {
444+
self.as_mut().reset_backoff();
445+
self.as_mut().reset_redirects();
446+
self.as_mut()
447+
.project()
448+
.state
449+
.set(State::Connected(resp.into_body()));
450+
continue;
406451
}
407452

408-
self.as_mut().reset_backoff();
409-
self.as_mut()
410-
.project()
411-
.state
412-
.set(State::Connected(resp.into_body()));
453+
if resp.status() == 301 || resp.status() == 307 {
454+
debug!("got redirected ({})", resp.status());
455+
456+
if self.as_mut().increment_redirect_counter() {
457+
debug!("following redirect {}", self.redirect_count);
458+
459+
self.as_mut().project().state.set(State::FollowingRedirect(
460+
resp.headers().get(hyper::header::LOCATION).cloned(),
461+
));
462+
continue;
463+
} else {
464+
debug!("redirect limit reached ({})", self.props.max_redirects);
465+
466+
self.as_mut().project().state.set(State::StreamClosed);
467+
return Poll::Ready(Some(Err(Error::MaxRedirectLimitReached(
468+
self.props.max_redirects,
469+
))));
470+
}
471+
}
472+
473+
self.as_mut().reset_redirects();
474+
self.as_mut().project().state.set(State::New);
475+
return Poll::Ready(Some(Err(Error::UnexpectedResponse(resp.status()))));
413476
}
414477
Err(e) => {
415478
// This seems basically impossible. AFAIK we can only get this way if we
@@ -426,6 +489,16 @@ where
426489
.set(State::WaitingToReconnect(delay(duration, "retrying")))
427490
}
428491
},
492+
StateProj::FollowingRedirect(maybe_header) => match uri_from_header(maybe_header) {
493+
Ok(uri) => {
494+
*self.as_mut().project().current_url = uri;
495+
self.as_mut().project().state.set(State::New);
496+
}
497+
Err(e) => {
498+
self.as_mut().project().state.set(State::StreamClosed);
499+
return Poll::Ready(Some(Err(e)));
500+
}
501+
},
429502
StateProj::Connected(body) => match ready!(body.poll_data(cx)) {
430503
Some(Ok(result)) => {
431504
this.event_parser.process_bytes(result)?;
@@ -473,6 +546,23 @@ where
473546
}
474547
}
475548

549+
fn uri_from_header(maybe_header: &Option<HeaderValue>) -> Result<Uri> {
550+
let header = maybe_header.as_ref().ok_or_else(|| {
551+
Error::MalformedLocationHeader(Box::new(std::io::Error::new(
552+
ErrorKind::NotFound,
553+
"missing Location header",
554+
)))
555+
})?;
556+
557+
let header_string = header
558+
.to_str()
559+
.map_err(|e| Error::MalformedLocationHeader(Box::new(e)))?;
560+
561+
header_string
562+
.parse::<Uri>()
563+
.map_err(|e| Error::MalformedLocationHeader(Box::new(e)))
564+
}
565+
476566
fn delay(dur: Duration, description: &str) -> Sleep {
477567
info!("Waiting {:?} before {}", dur, description);
478568
tokio::time::sleep(dur)

eventsource-client/src/error.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ pub enum Error {
77
StreamClosed,
88
/// An invalid request parameter
99
InvalidParameter(Box<dyn std::error::Error + Send + 'static>),
10-
/// The HTTP request failed.
11-
HttpRequest(StatusCode),
10+
/// The HTTP response could not be handled.
11+
UnexpectedResponse(StatusCode),
1212
/// An error reading from the HTTP response body.
1313
HttpStream(Box<dyn std::error::Error + Send + 'static>),
1414
/// The HTTP response stream ended
@@ -19,6 +19,10 @@ pub enum Error {
1919
/// Encountered a line not conforming to the SSE protocol.
2020
InvalidLine(String),
2121
InvalidEvent,
22+
/// Encountered a malformed Location header.
23+
MalformedLocationHeader(Box<dyn std::error::Error + Send + 'static>),
24+
/// Reached maximum redirect limit after encountering Location headers.
25+
MaxRedirectLimitReached(u32),
2226
/// An unexpected failure occurred.
2327
Unexpected(Box<dyn std::error::Error + Send + 'static>),
2428
}

0 commit comments

Comments
 (0)