22#[ cfg( test) ]
33mod tests;
44
5- use std:: { collections:: HashSet , time:: Duration } ;
5+ use std:: {
6+ collections:: HashSet ,
7+ sync:: Arc ,
8+ time:: { Duration , Instant } ,
9+ } ;
610
711use async_trait:: async_trait;
812use backoff:: { Error as BackoffError , ExponentialBackoff , future:: retry} ;
913use graphql_client:: { GraphQLQuery , Response } ;
1014use mockall:: automock;
15+ use rand:: { Rng , rng} ;
1116use reqwest:: {
1217 Client ,
1318 header:: { AUTHORIZATION , HeaderMap , HeaderValue , USER_AGENT } ,
1419} ;
1520use thiserror:: Error ;
21+ use tokio:: sync:: Mutex ;
22+
23+ #[ derive( Debug ) ]
24+ struct RateLimitState {
25+ remaining : u32 ,
26+ reset_at : Instant ,
27+ }
1628
1729/// Represents errors that can occur when interacting with the GitHub API.
1830#[ derive( Debug , Error ) ]
@@ -43,6 +55,10 @@ pub enum GithubError {
4355 /// An error indicating that the request was not authorized.
4456 #[ error( "GitHub authentication failed" ) ]
4557 Unauthorized ,
58+
59+ /// An error indicating that a required header could not be parsed.
60+ #[ error( "Failed to parse header: {0}" ) ]
61+ HeaderError ( String ) ,
4662}
4763
4864// Helper function to check if a GraphQL error is retryable
@@ -117,22 +133,33 @@ pub struct Labels;
117133pub struct DefaultGithubClient {
118134 client : Client ,
119135 graphql_url : String ,
136+ rate_limit : Arc < Mutex < RateLimitState > > ,
137+ rate_limit_threshold : u64 ,
120138}
121139
122140impl DefaultGithubClient {
123141 /// Creates a new `DefaultGithubClient`.
124- pub fn new ( github_token : & str , graphql_url : & str ) -> Result < Self , GithubError > {
142+ pub fn new (
143+ github_token : & str ,
144+ graphql_url : & str ,
145+ rate_limit_threshold : u64 ,
146+ ) -> Result < Self , GithubError > {
125147 // Build the HTTP client with the GitHub token.
126148 let mut headers = HeaderMap :: new ( ) ;
127149
128150 headers. insert ( AUTHORIZATION , HeaderValue :: from_str ( & format ! ( "Bearer {github_token}" ) ) ?) ;
129151 headers. insert ( USER_AGENT , HeaderValue :: from_static ( "github-activity-rs" ) ) ;
130152
131153 let client = reqwest:: Client :: builder ( ) . default_headers ( headers) . build ( ) ?;
132-
154+ let initial_state = RateLimitState { remaining : u32 :: MAX , reset_at : Instant :: now ( ) } ;
133155 tracing:: debug!( "HTTP client built successfully." ) ;
134156
135- Ok ( Self { client, graphql_url : graphql_url. to_string ( ) } )
157+ Ok ( Self {
158+ client,
159+ graphql_url : graphql_url. to_string ( ) ,
160+ rate_limit : Arc :: new ( Mutex :: new ( initial_state) ) ,
161+ rate_limit_threshold,
162+ } )
136163 }
137164
138165 /// Re-usable configuration for exponential backoff.
@@ -158,6 +185,9 @@ impl DefaultGithubClient {
158185 {
159186 // closure that Backoff expects
160187 let operation = || async {
188+ // 0. Rate limit guard
189+ self . rate_limit_guard ( ) . await ;
190+
161191 // 1. Build the request
162192 let request_body = Q :: build_query ( variables. clone ( ) ) ;
163193
@@ -170,7 +200,13 @@ impl DefaultGithubClient {
170200 } ,
171201 ) ?;
172202
173- // 3. HTTP-status check
203+ //3 Update rate limit state from headers
204+ if let Err ( e) = self . update_rate_limit_from_headers ( resp. headers ( ) ) . await {
205+ // Option A: warn and continue
206+ tracing:: warn!( "Could not update rate-limit info: {}" , e) ;
207+ }
208+
209+ // 4. HTTP-status check
174210 if !resp. status ( ) . is_success ( ) {
175211 let status = resp. status ( ) ;
176212 let text = resp. text ( ) . await . unwrap_or_else ( |e| {
@@ -214,15 +250,15 @@ impl DefaultGithubClient {
214250 return Err ( be) ;
215251 }
216252
217- // 4 . Parse JSON
253+ // 5 . Parse JSON
218254 let body: Response < Q :: ResponseData > = resp. json ( ) . await . map_err ( |e| {
219255 tracing:: warn!( "Failed to parse JSON: {e}. Retrying..." ) ;
220256 BackoffError :: transient ( GithubError :: GraphQLApiError ( format ! (
221257 "JSON parse error: {e}"
222258 ) ) )
223259 } ) ?;
224260
225- // 5 . GraphQL errors?
261+ // 6 . GraphQL errors?
226262 if let Some ( errors) = & body. errors {
227263 let is_rate_limit_error = errors. iter ( ) . any ( |e| {
228264 e. message . to_lowercase ( ) . contains ( "rate limit" ) || is_retryable_graphql_error ( e)
@@ -239,7 +275,7 @@ impl DefaultGithubClient {
239275 }
240276 }
241277
242- // 6 . Unwrap the data or permanent-fail
278+ // 7 . Unwrap the data or permanent-fail
243279 body. data . ok_or_else ( || {
244280 tracing:: error!( "GraphQL response had no data field; permanent failure" ) ;
245281 BackoffError :: permanent ( GithubError :: GraphQLApiError (
@@ -251,6 +287,80 @@ impl DefaultGithubClient {
251287 // kick off the retry loop
252288 retry ( Self :: backoff_config ( ) , operation) . await
253289 }
290+
291+ /// Rate limit guard that sleeps until the rate limit resets if we're close
292+ /// to the threshold.
293+ async fn rate_limit_guard ( & self ) {
294+ let ( remaining, reset_at) = {
295+ let state = self . rate_limit . lock ( ) . await ;
296+ ( state. remaining , state. reset_at )
297+ } ;
298+
299+ // define a safety threshold
300+ let threshold = self . rate_limit_threshold as u32 ;
301+ if remaining <= threshold {
302+ let now = Instant :: now ( ) ;
303+ if now < reset_at {
304+ let wait = reset_at - now;
305+ tracing:: info!(
306+ "Approaching rate limit ({} left). Sleeping {:?} until reset..." ,
307+ remaining,
308+ wait
309+ ) ;
310+
311+ // Sleep until the rate limit resets
312+ //added a jitter to avoid thundering herd problem
313+ let max_jitter = wait. as_millis ( ) as u64 / 10 ;
314+ let jitter_ms = rng ( ) . random_range ( 0 ..=max_jitter) ;
315+ tokio:: time:: sleep ( wait + Duration :: from_millis ( jitter_ms) ) . await ;
316+ }
317+ }
318+ }
319+
320+ /// Update the rate limit state from the response headers.
321+ async fn update_rate_limit_from_headers ( & self , headers : & HeaderMap ) -> Result < ( ) , GithubError > {
322+ // Names are case-insensitive in HeaderMap
323+ let rem_val = headers. get ( "X-RateLimit-Remaining" ) . ok_or_else ( || {
324+ let msg = "Missing X-RateLimit-Remaining header" . to_string ( ) ;
325+ tracing:: error!( "{}" , msg) ;
326+ GithubError :: HeaderError ( msg)
327+ } ) ?;
328+ let rem_str = rem_val. to_str ( ) . map_err ( |e| {
329+ let msg = format ! ( "Invalid X-RateLimit-Remaining value: {e}" ) ;
330+ tracing:: error!( "{}" , msg) ;
331+ GithubError :: HeaderError ( msg)
332+ } ) ?;
333+ let remaining = rem_str. parse :: < u32 > ( ) . map_err ( |e| {
334+ let msg = format ! ( "Cannot parse remaining as u32: {e}" ) ;
335+ tracing:: error!( "{}" , msg) ;
336+ GithubError :: HeaderError ( msg)
337+ } ) ?;
338+
339+ let reset_val = headers. get ( "X-RateLimit-Reset" ) . ok_or_else ( || {
340+ let msg = "Missing X-RateLimit-Reset header" . to_string ( ) ;
341+ tracing:: error!( "{}" , msg) ;
342+ GithubError :: HeaderError ( msg)
343+ } ) ?;
344+ let reset_str = reset_val. to_str ( ) . map_err ( |e| {
345+ let msg = format ! ( "Invalid X-RateLimit-Reset value: {e}" ) ;
346+ tracing:: error!( "{}" , msg) ;
347+ GithubError :: HeaderError ( msg)
348+ } ) ?;
349+ let reset_unix = reset_str. parse :: < u64 > ( ) . map_err ( |e| {
350+ let msg = format ! ( "Cannot parse reset timestamp as u64: {e}" ) ;
351+ tracing:: error!( "{}" , msg) ;
352+ GithubError :: HeaderError ( msg)
353+ } ) ?;
354+
355+ // All good — update the shared state
356+ let mut state = self . rate_limit . lock ( ) . await ;
357+ state. remaining = remaining;
358+ let reset_in = reset_unix. saturating_sub ( chrono:: Utc :: now ( ) . timestamp ( ) as u64 ) ;
359+ state. reset_at = Instant :: now ( ) + Duration :: from_secs ( reset_in) ;
360+
361+ tracing:: debug!( "Rate limit updated: {} remaining, resets in {}s" , remaining, reset_in) ;
362+ Ok ( ( ) )
363+ }
254364}
255365
256366#[ async_trait]
0 commit comments