diff --git a/Cargo.lock b/Cargo.lock index ccfa7e2..4fedd40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -727,6 +727,7 @@ dependencies = [ "graphql_client", "lazy_static", "mockall", + "rand 0.9.1", "reqwest 0.12.15", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 3334be5..8417e24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ thiserror = "2.0.12" tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } futures = "0.3.31" +rand = "0.9.1" diff --git a/src/bot_handler/callbacks/toggle_label.rs b/src/bot_handler/callbacks/toggle_label.rs index f97c29e..8a5d9df 100644 --- a/src/bot_handler/callbacks/toggle_label.rs +++ b/src/bot_handler/callbacks/toggle_label.rs @@ -18,10 +18,11 @@ pub async fn handle(ctx: Context<'_>, label_name: &str, label_page: usize) -> Bo let (repo_id, from_page) = match dialogue_state { Some(CommandState::ViewingRepoLabels { repo_id, from_page }) => (repo_id, from_page), - _ => + _ => { return Err(BotHandlerError::InvalidInput( "Invalid state: expected ViewingRepoLabels".to_string(), - )), + )); + } }; let repo = diff --git a/src/config.rs b/src/config.rs index c902a28..b718c4a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,6 +6,7 @@ const DEFAULT_POLL_INTERVAL: u64 = 10; const DEFAULT_REPOS_PER_USER: usize = 20; const DEFAULT_LABELS_PER_REPO: usize = 10; const DEFAULT_MAX_CONCURRENCY: usize = 10; +const DEFAULT_RATE_LIMIT_THRESHOLD: u64 = 10; /// Represents the application configuration. #[derive(Debug)] @@ -26,6 +27,8 @@ pub struct Config { pub max_labels_per_repo: usize, /// The maximum number of concurrent requests to make to the GitHub API. pub max_concurrency: usize, + /// The threshold before the bot should pause operations. + pub rate_limit_threshold: u64, } impl Config { @@ -54,6 +57,10 @@ impl Config { .ok() .and_then(|v| v.parse().ok()) .unwrap_or(DEFAULT_MAX_CONCURRENCY), + rate_limit_threshold: env::var("RATE_LIMIT_THRESHOLD") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(DEFAULT_RATE_LIMIT_THRESHOLD), }) } } diff --git a/src/github/mod.rs b/src/github/mod.rs index fa0f6b5..ebfd63a 100644 --- a/src/github/mod.rs +++ b/src/github/mod.rs @@ -2,17 +2,29 @@ #[cfg(test)] mod tests; -use std::{collections::HashSet, time::Duration}; +use std::{ + collections::HashSet, + sync::Arc, + time::{Duration, Instant}, +}; use async_trait::async_trait; use backoff::{Error as BackoffError, ExponentialBackoff, future::retry}; use graphql_client::{GraphQLQuery, Response}; use mockall::automock; +use rand::{Rng, rng}; use reqwest::{ Client, header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT}, }; use thiserror::Error; +use tokio::sync::Mutex; + +#[derive(Debug)] +struct RateLimitState { + remaining: u32, + reset_at: Instant, +} /// Represents errors that can occur when interacting with the GitHub API. #[derive(Debug, Error)] @@ -43,6 +55,10 @@ pub enum GithubError { /// An error indicating that the request was not authorized. #[error("GitHub authentication failed")] Unauthorized, + + /// An error indicating that a required header could not be parsed. + #[error("Failed to parse header: {0}")] + HeaderError(String), } // Helper function to check if a GraphQL error is retryable @@ -117,11 +133,17 @@ pub struct Labels; pub struct DefaultGithubClient { client: Client, graphql_url: String, + rate_limit: Arc>, + rate_limit_threshold: u64, } impl DefaultGithubClient { /// Creates a new `DefaultGithubClient`. - pub fn new(github_token: &str, graphql_url: &str) -> Result { + pub fn new( + github_token: &str, + graphql_url: &str, + rate_limit_threshold: u64, + ) -> Result { // Build the HTTP client with the GitHub token. let mut headers = HeaderMap::new(); @@ -129,10 +151,15 @@ impl DefaultGithubClient { headers.insert(USER_AGENT, HeaderValue::from_static("github-activity-rs")); let client = reqwest::Client::builder().default_headers(headers).build()?; - + let initial_state = RateLimitState { remaining: u32::MAX, reset_at: Instant::now() }; tracing::debug!("HTTP client built successfully."); - Ok(Self { client, graphql_url: graphql_url.to_string() }) + Ok(Self { + client, + graphql_url: graphql_url.to_string(), + rate_limit: Arc::new(Mutex::new(initial_state)), + rate_limit_threshold, + }) } /// Re-usable configuration for exponential backoff. @@ -158,6 +185,9 @@ impl DefaultGithubClient { { // closure that Backoff expects let operation = || async { + // 0. Rate limit guard + self.rate_limit_guard().await; + // 1. Build the request let request_body = Q::build_query(variables.clone()); @@ -170,7 +200,13 @@ impl DefaultGithubClient { }, )?; - // 3. HTTP-status check + //3 Update rate limit state from headers + if let Err(e) = self.update_rate_limit_from_headers(resp.headers()).await { + // Option A: warn and continue + tracing::warn!("Could not update rate-limit info: {}", e); + } + + // 4. HTTP-status check if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await.unwrap_or_else(|e| { @@ -214,7 +250,7 @@ impl DefaultGithubClient { return Err(be); } - // 4. Parse JSON + // 5. Parse JSON let body: Response = resp.json().await.map_err(|e| { tracing::warn!("Failed to parse JSON: {e}. Retrying..."); BackoffError::transient(GithubError::GraphQLApiError(format!( @@ -222,7 +258,7 @@ impl DefaultGithubClient { ))) })?; - // 5. GraphQL errors? + // 6. GraphQL errors? if let Some(errors) = &body.errors { let is_rate_limit_error = errors.iter().any(|e| { e.message.to_lowercase().contains("rate limit") || is_retryable_graphql_error(e) @@ -239,7 +275,7 @@ impl DefaultGithubClient { } } - // 6. Unwrap the data or permanent-fail + // 7. Unwrap the data or permanent-fail body.data.ok_or_else(|| { tracing::error!("GraphQL response had no data field; permanent failure"); BackoffError::permanent(GithubError::GraphQLApiError( @@ -251,6 +287,80 @@ impl DefaultGithubClient { // kick off the retry loop retry(Self::backoff_config(), operation).await } + + /// Rate limit guard that sleeps until the rate limit resets if we're close + /// to the threshold. + async fn rate_limit_guard(&self) { + let (remaining, reset_at) = { + let state = self.rate_limit.lock().await; + (state.remaining, state.reset_at) + }; + + // define a safety threshold + let threshold = self.rate_limit_threshold as u32; + if remaining <= threshold { + let now = Instant::now(); + if now < reset_at { + let wait = reset_at - now; + tracing::info!( + "Approaching rate limit ({} left). Sleeping {:?} until reset...", + remaining, + wait + ); + + // Sleep until the rate limit resets + //added a jitter to avoid thundering herd problem + let max_jitter = wait.as_millis() as u64 / 10; + let jitter_ms = rng().random_range(0..=max_jitter); + tokio::time::sleep(wait + Duration::from_millis(jitter_ms)).await; + } + } + } + + /// Update the rate limit state from the response headers. + async fn update_rate_limit_from_headers(&self, headers: &HeaderMap) -> Result<(), GithubError> { + // Names are case-insensitive in HeaderMap + let rem_val = headers.get("X-RateLimit-Remaining").ok_or_else(|| { + let msg = "Missing X-RateLimit-Remaining header".to_string(); + tracing::error!("{}", msg); + GithubError::HeaderError(msg) + })?; + let rem_str = rem_val.to_str().map_err(|e| { + let msg = format!("Invalid X-RateLimit-Remaining value: {e}"); + tracing::error!("{}", msg); + GithubError::HeaderError(msg) + })?; + let remaining = rem_str.parse::().map_err(|e| { + let msg = format!("Cannot parse remaining as u32: {e}"); + tracing::error!("{}", msg); + GithubError::HeaderError(msg) + })?; + + let reset_val = headers.get("X-RateLimit-Reset").ok_or_else(|| { + let msg = "Missing X-RateLimit-Reset header".to_string(); + tracing::error!("{}", msg); + GithubError::HeaderError(msg) + })?; + let reset_str = reset_val.to_str().map_err(|e| { + let msg = format!("Invalid X-RateLimit-Reset value: {e}"); + tracing::error!("{}", msg); + GithubError::HeaderError(msg) + })?; + let reset_unix = reset_str.parse::().map_err(|e| { + let msg = format!("Cannot parse reset timestamp as u64: {e}"); + tracing::error!("{}", msg); + GithubError::HeaderError(msg) + })?; + + // All good — update the shared state + let mut state = self.rate_limit.lock().await; + state.remaining = remaining; + let reset_in = reset_unix.saturating_sub(chrono::Utc::now().timestamp() as u64); + state.reset_at = Instant::now() + Duration::from_secs(reset_in); + + tracing::debug!("Rate limit updated: {} remaining, resets in {}s", remaining, reset_in); + Ok(()) + } } #[async_trait] diff --git a/src/github/tests.rs b/src/github/tests.rs index e533ee3..ad33231 100644 --- a/src/github/tests.rs +++ b/src/github/tests.rs @@ -1,10 +1,9 @@ use std::collections::HashMap; use super::*; - #[test] fn test_new_github_client() { - let client = DefaultGithubClient::new("test_token", "https://api.github.com/graphql"); + let client = DefaultGithubClient::new("test_token", "https://api.github.com/graphql", 10); assert!(client.is_ok()); } @@ -37,3 +36,165 @@ fn test_is_not_retryable_graphql_error() { assert!(!is_retryable_graphql_error(&error)); } + +#[tokio::test] +async fn test_update_rate_limit_from_headers() { + let client = + DefaultGithubClient::new("fake", "https://api.github.com/graphql", 5).expect("client init"); + + // Build fake headers with remaining=3, reset in 60s + let mut headers = HeaderMap::new(); + headers.insert("X-RateLimit-Remaining", HeaderValue::from_static("3")); + let reset_ts = (chrono::Utc::now().timestamp() as u64) + 60; + headers.insert("X-RateLimit-Reset", HeaderValue::from_str(&reset_ts.to_string()).unwrap()); + + client.update_rate_limit_from_headers(&headers).await; + + let state = client.rate_limit.lock().await; + assert_eq!(state.remaining, 3); + // We expect reset_at ≈ now + 60s (within a small delta) + let diff = state.reset_at.checked_duration_since(Instant::now()).unwrap(); + assert!(diff >= Duration::from_secs(59) && diff <= Duration::from_secs(61)); +} + +#[tokio::test(flavor = "multi_thread")] +async fn rate_limit_guard_sleeps_when_below_threshold_and_before_reset() { + // -------- Arrange -------- + let threshold = 5; + let client = + DefaultGithubClient::new("fake_token", "https://example.com/graphql", threshold as u64) + .expect("client"); + + const WAIT_MS: u64 = 40; + { + let mut state = client.rate_limit.lock().await; + state.remaining = threshold; + state.reset_at = Instant::now() + Duration::from_millis(WAIT_MS); + } + + let expected_min = Duration::from_millis(WAIT_MS); + let expected_max = Duration::from_millis(WAIT_MS + WAIT_MS / 10); + let fudge = Duration::from_millis(10); + + // -------- Act -------- + let start = Instant::now(); + client.rate_limit_guard().await; + let elapsed = start.elapsed(); + + // -------- Assert -------- + assert!(elapsed >= expected_min, "Guard returned too fast: {:?} < {:?}", elapsed, expected_min); + assert!( + elapsed <= expected_max + fudge, + "Guard slept too long: {:?} > {:?}", + elapsed, + expected_max + fudge + ); +} + +/// Helper: set the shared rate-limit state so the guard will sleep `wait_ms` +/// plus jitter. +async fn prime_state(client: &DefaultGithubClient, remaining: u32, wait_ms: u64) { + let mut s = client.rate_limit.lock().await; + s.remaining = remaining; + s.reset_at = Instant::now() + Duration::from_millis(wait_ms); +} + +/// Helper: run the guard once and return how long it actually waited. +async fn measure_sleep(client: &DefaultGithubClient) -> Duration { + let start = Instant::now(); + client.rate_limit_guard().await; + start.elapsed() +} + +/// 1) Always inside [wait, wait + 10%] (with a little fudge for scheduler +/// noise) +#[tokio::test(flavor = "multi_thread")] +async fn jitter_is_within_bounds() { + const THRESHOLD: u64 = 5; + const WAIT_MS: u64 = 50; + const FUDGE_MS: u64 = 8; + + let client = + DefaultGithubClient::new("fake", "https://example/graphql", THRESHOLD).expect("client"); + + // Force a sleep path + prime_state(&client, THRESHOLD as u32, WAIT_MS).await; + + let expected_min = Duration::from_millis(WAIT_MS); + let expected_max = Duration::from_millis(WAIT_MS + WAIT_MS / 10); // 10% jitter + let fudge = Duration::from_millis(FUDGE_MS); + + let elapsed = measure_sleep(&client).await; + + assert!(elapsed >= expected_min, "Returned too fast: {:?} < {:?}", elapsed, expected_min); + assert!( + elapsed <= expected_max + fudge, + "Slept too long: {:?} > {:?}", + elapsed, + expected_max + fudge + ); +} + +/// 2) We actually *get* jitter sometimes (i.e., not always exactly WAIT_MS). +/// Run the guard a bunch of times and check the spread of durations. +#[tokio::test(flavor = "multi_thread")] +async fn jitter_varies_across_runs() { + const THRESHOLD: u64 = 3; + const WAIT_MS: u64 = 40; + const RUNS: usize = 20; + const FUDGE_MS: u64 = 8; + + let client = + DefaultGithubClient::new("fake", "https://example/graphql", THRESHOLD).expect("client"); + + let mut samples = Vec::with_capacity(RUNS); + + for _ in 0..RUNS { + prime_state(&client, THRESHOLD as u32, WAIT_MS).await; + samples.push(measure_sleep(&client).await); + } + + let min = samples.iter().min().cloned().unwrap(); + let max = samples.iter().max().cloned().unwrap(); + + let base = Duration::from_millis(WAIT_MS); + let jitter_span = Duration::from_millis(WAIT_MS / 10); + + // Same bounds check as safety + for (i, dur) in samples.iter().enumerate() { + assert!(*dur >= base, "Run {i}: {:?} < base {:?}", dur, base); + assert!( + *dur <= base + jitter_span + Duration::from_millis(FUDGE_MS), + "Run {i}: {:?} > upper bound {:?}", + dur, + base + jitter_span + Duration::from_millis(FUDGE_MS) + ); + } + + // And now confirm we saw at least ~some spread (non‑deterministic, so we only + ///require >1ms spread). + assert!( + max > min + Duration::from_millis(1), + "Jitter didn't vary enough: min={:?}, max={:?}", + min, + max + ); +} + +/// 3) When wait == 0, max_jitter == 0 → no sleep at all. +#[tokio::test(flavor = "multi_thread")] +async fn no_jitter_when_wait_is_zero() { + const THRESHOLD: u64 = 1; + let client = + DefaultGithubClient::new("fake", "https://example/graphql", THRESHOLD).expect("client"); + + // Force path where remaining <= threshold but reset_at == now + let mut s = client.rate_limit.lock().await; + s.remaining = THRESHOLD as u32; + s.reset_at = Instant::now(); // so wait = 0 + drop(s); + + // Should return basically immediately + let elapsed = measure_sleep(&client).await; + assert!(elapsed < Duration::from_millis(2), "Guard unexpectedly slept: {:?}", elapsed); +} diff --git a/src/lib.rs b/src/lib.rs index db0cd03..2c96fae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,6 +45,7 @@ pub async fn run() -> Result<(), Box> { let github_client = Arc::new(github::DefaultGithubClient::new( &config.github_token, &config.github_graphql_url, + config.rate_limit_threshold, )?); let messaging_service = Arc::new(TelegramMessagingService::new(bot.clone())); diff --git a/src/poller/mod.rs b/src/poller/mod.rs index 0fdea23..a8ff316 100644 --- a/src/poller/mod.rs +++ b/src/poller/mod.rs @@ -199,6 +199,15 @@ impl GithubPoller { ); return Err(PollerError::Github(github_error)); } + GithubError::HeaderError(msg) => { + tracing::warn!( + "Could not parse rate limit headers for repo {} (chat {}): {}. \ + Skipping this repo for this cycle.", + repo.name_with_owner, + chat_id, + msg + ); + } }, unexpected_error => { tracing::error!(