From 0c6999189edb82c20f5268e514a115c023551d5c Mon Sep 17 00:00:00 2001 From: Jareth Gomes Date: Sat, 4 Oct 2025 02:36:45 +0000 Subject: [PATCH 1/4] chore: Add tests for cooldown module --- src/cooldown.rs | 389 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 389 insertions(+) diff --git a/src/cooldown.rs b/src/cooldown.rs index 8afba194de8..e33f12fe9dd 100644 --- a/src/cooldown.rs +++ b/src/cooldown.rs @@ -173,3 +173,392 @@ impl<'a> From<&'a serenity::Message> for CooldownContext { } } } + +#[cfg(test)] +mod test { + use ::serenity::all::{ChannelId, GuildId, UserId}; + + use super::*; + + #[test] + fn start_cooldown_triggers_guild_cooldown() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } + + #[test] + fn cooldown_resets_after_window_expires() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let tracker = CooldownTracker { + global_invocation: Some(Instant::now() - Duration::from_secs(1)), + ..Default::default() + }; + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + } + + #[tokio::test] + async fn basic_global_cooldown() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn global_cooldown_affects_other_users() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + let cooldown = tracker.remaining_cooldown(other_user_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } + + #[tokio::test] + async fn basic_user_cooldown() { + let config = CooldownConfig { + user: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn user_cooldown_does_not_affect_other_users() { + let config = CooldownConfig { + user: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker + .remaining_cooldown(other_user_ctx, &config) + .is_none()); + + let same_user_different_channel_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(9876), + }; + + let cooldown = tracker.remaining_cooldown(same_user_different_channel_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } + + #[tokio::test] + async fn basic_channel_cooldown() { + let config = CooldownConfig { + channel: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()) + } + + #[test] + fn channel_cooldown_affects_one_channel() { + let config = CooldownConfig { + channel: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + let cooldown = tracker.remaining_cooldown(other_user_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let same_user_different_channel_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(9876), + }; + + assert!(tracker + .remaining_cooldown(same_user_different_channel_ctx, &config) + .is_none()); + } + + #[tokio::test] + async fn basic_guild_cooldown() { + let config = CooldownConfig { + guild: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn guild_cooldown_affects_one_guild() { + let config = CooldownConfig { + guild: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + let cooldown = tracker.remaining_cooldown(other_user_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + // This is not realistic since while guild id is different, the channel id is the same, but + // this is to demostrate the guild id affects the cooldown. + let same_user_different_guild_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(420)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker + .remaining_cooldown(same_user_different_guild_ctx, &config) + .is_none()); + } + + #[tokio::test] + async fn basic_member_cooldown() { + let config = CooldownConfig { + member: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn member_cooldown_affects_one_member() { + let config = CooldownConfig { + member: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker + .remaining_cooldown(other_user_ctx, &config) + .is_none()); + + let same_user_different_channel_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(9876), + }; + let cooldown = tracker.remaining_cooldown(same_user_different_channel_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } +} From 7de210437ed4ecc1788e3abda3461cdf30fff389 Mon Sep 17 00:00:00 2001 From: Jareth Gomes Date: Fri, 3 Oct 2025 07:53:52 +0000 Subject: [PATCH 2/4] feat!: Add burstable counts to cooldowns This enables for commands to be invoked multiple times before they get put on cooldown. A cooldown cycle begins with the first instance a command is invoked and ends when the cooldown expires. When the cooldown expires, the number of invokes also resets to 0. By default, the burstable amount for each type of cooldown is set to 1, which will mimic the old behavior, only allowing for 1 command invoke per cooldown cycle. Command attribute macros that do not explicitly add these fields will inherit the default value. The reason why this change is marked as breaking is due to `CooldownConfig` having additional public fields being added. This struct was already marked as non-exhaustive, so the user should expect this struct to break. --- macros/src/command/mod.rs | 16 ++++ macros/src/command/prefix.rs | 3 +- macros/src/command/slash.rs | 3 +- macros/src/lib.rs | 5 + src/cooldown.rs | 180 +++++++++++++++++++++++++++++++---- 5 files changed, 186 insertions(+), 21 deletions(-) diff --git a/macros/src/command/mod.rs b/macros/src/command/mod.rs index 3c9d82feb1f..0e4e1f6b76b 100644 --- a/macros/src/command/mod.rs +++ b/macros/src/command/mod.rs @@ -59,6 +59,12 @@ pub struct CommandArgs { guild_cooldown: Option, channel_cooldown: Option, member_cooldown: Option, + + global_cooldown_burst: Option, + user_cooldown_burst: Option, + guild_cooldown_burst: Option, + channel_cooldown_burst: Option, + member_cooldown_burst: Option, } /// Representation of the function parameter attribute arguments @@ -429,18 +435,28 @@ fn generate_cooldown_config(args: &CommandArgs) -> proc_macro2::TokenStream { let to_seconds_path = quote::quote!(std::time::Duration::from_secs); let global_cooldown = wrap_option_and_map(args.global_cooldown, &to_seconds_path); + let global_burst = wrap_option(args.global_cooldown_burst); let user_cooldown = wrap_option_and_map(args.user_cooldown, &to_seconds_path); + let user_burst = wrap_option(args.user_cooldown_burst); let guild_cooldown = wrap_option_and_map(args.guild_cooldown, &to_seconds_path); + let guild_burst = wrap_option(args.guild_cooldown_burst); let channel_cooldown = wrap_option_and_map(args.channel_cooldown, &to_seconds_path); + let channel_burst = wrap_option(args.channel_cooldown_burst); let member_cooldown = wrap_option_and_map(args.member_cooldown, &to_seconds_path); + let member_burst = wrap_option(args.member_cooldown_burst); quote::quote!( std::sync::RwLock::new(::poise::CooldownConfig { global: #global_cooldown, + global_burst_amount: #global_burst, user: #user_cooldown, + user_burst_amount: #user_burst, guild: #guild_cooldown, + guild_burst_amount: #guild_burst, channel: #channel_cooldown, + channel_burst_amount: #channel_burst, member: #member_cooldown, + member_burst_amount: #member_burst, __non_exhaustive: () }) ) diff --git a/macros/src/command/prefix.rs b/macros/src/command/prefix.rs index d4750ae1216..bdb84961dfb 100644 --- a/macros/src/command/prefix.rs +++ b/macros/src/command/prefix.rs @@ -70,7 +70,8 @@ pub fn generate_prefix_action(inv: &Invocation) -> Result as core::ops::Deref>::deref(&cooldown_config)); } inner(ctx.into(), #( #param_idents, )* ) diff --git a/macros/src/command/slash.rs b/macros/src/command/slash.rs index bb316e8c9a1..81bf1d4e98c 100644 --- a/macros/src/command/slash.rs +++ b/macros/src/command/slash.rs @@ -190,7 +190,8 @@ pub fn generate_slash_action(inv: &Invocation) -> Result as core::ops::Deref>::deref(&cooldown_config)); } inner(ctx.into(), #( #param_identifiers, )*) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 802be75e9b2..e319593b408 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -69,10 +69,15 @@ for example for command-specific help (i.e. `~help command_name`). Escape newlin ## Cooldown - `manual_cooldowns`: Allows overriding the framework's built-in cooldowns tracking without affecting other commands. - `global_cooldown`: Minimum duration in seconds between invocations, globally +- `global_cooldown_burst`: The number of times the command can be invoked within the cooldown period, globally. The default is 1. - `user_cooldown`: Minimum duration in seconds between invocations, per user +- `user_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per user. The default is 1. - `guild_cooldown`: Minimum duration in seconds between invocations, per guild +- `guild_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per guild. The default is 1. - `channel_cooldown`: Minimum duration in seconds between invocations, per channel +- `channel_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per channel. The default is 1. - `member_cooldown`: Minimum duration in seconds between invocations, per guild member +- `member_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per guild member. The default is 1. ## Other diff --git a/src/cooldown.rs b/src/cooldown.rs index e33f12fe9dd..7fc78188bd6 100644 --- a/src/cooldown.rs +++ b/src/cooldown.rs @@ -5,6 +5,11 @@ use crate::serenity_prelude as serenity; use std::collections::HashMap; use std::time::{Duration, Instant}; +/// The starting value that gets assigned when inserted into [`CooldownTracker`] +const STARTING_BURST_AMOUNT: u64 = 1; +/// The default burst amount if it is not provided in [`CooldownConfig`] +const DEFAULT_BURST_AMOUNT: u64 = 1; + /// Subset of [`crate::Context`] so that [`Cooldowns`] can be used without requiring a full [Context](`crate::Context`) /// (ie from within an `event_handler`) #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] @@ -22,14 +27,24 @@ pub struct CooldownContext { pub struct CooldownConfig { /// This cooldown operates on a global basis pub global: Option, + /// This is how many operations can be invoked within the cooldown period on a global basis + pub global_burst_amount: Option, /// This cooldown operates on a per-user basis pub user: Option, + /// This is how many operations can be invoked within the cooldown period on a per-user basis + pub user_burst_amount: Option, /// This cooldown operates on a per-guild basis pub guild: Option, + /// This is how many operations can be invoked within the cooldown period on a per-guild basis + pub guild_burst_amount: Option, /// This cooldown operates on a per-channel basis pub channel: Option, + /// This is how many operations can be invoked within the cooldown period on a per-channel basis + pub channel_burst_amount: Option, /// This cooldown operates on a per-member basis pub member: Option, + /// This is how many operations can be invoked within the cooldown period on a per-member basis + pub member_burst_amount: Option, #[doc(hidden)] pub __non_exhaustive: (), } @@ -41,15 +56,15 @@ pub struct CooldownConfig { #[derive(Default, Clone, Debug, PartialEq, Eq)] pub struct CooldownTracker { /// Stores the timestamp of the last global invocation - global_invocation: Option, + global_invocation: Option<(Instant, u64)>, /// Stores the timestamps of the last invocation per user - user_invocations: HashMap, + user_invocations: HashMap, /// Stores the timestamps of the last invocation per guild - guild_invocations: HashMap, + guild_invocations: HashMap, /// Stores the timestamps of the last invocation per channel - channel_invocations: HashMap, + channel_invocations: HashMap, /// Stores the timestamps of the last invocation per member (user and guild) - member_invocations: HashMap<(serenity::UserId, serenity::GuildId), Instant>, + member_invocations: HashMap<(serenity::UserId, serenity::GuildId), (Instant, u64)>, } /// Possible types of command cooldowns. @@ -92,13 +107,19 @@ impl CooldownTracker { cooldown_durations: &CooldownConfig, ) -> Option { let mut cooldown_data = vec![ - (cooldown_durations.global, self.global_invocation), + ( + cooldown_durations.global, + cooldown_durations.global_burst_amount, + self.global_invocation, + ), ( cooldown_durations.user, + cooldown_durations.user_burst_amount, self.user_invocations.get(&ctx.user_id).copied(), ), ( cooldown_durations.channel, + cooldown_durations.channel_burst_amount, self.channel_invocations.get(&ctx.channel_id).copied(), ), ]; @@ -106,10 +127,12 @@ impl CooldownTracker { if let Some(guild_id) = ctx.guild_id { cooldown_data.push(( cooldown_durations.guild, + cooldown_durations.guild_burst_amount, self.guild_invocations.get(&guild_id).copied(), )); cooldown_data.push(( cooldown_durations.member, + cooldown_durations.member_burst_amount, self.member_invocations .get(&(ctx.user_id, guild_id)) .copied(), @@ -118,8 +141,12 @@ impl CooldownTracker { cooldown_data .iter() - .filter_map(|&(cooldown, last_invocation)| { - let duration_since = Instant::now().saturating_duration_since(last_invocation?); + .filter_map(|&(cooldown, burst_amount, last_invocation)| { + let last_invocation = last_invocation?; + if burst_amount.unwrap_or(DEFAULT_BURST_AMOUNT) > last_invocation.1 { + return None; + } + let duration_since = Instant::now().saturating_duration_since(last_invocation.0); let cooldown_left = cooldown?.checked_sub(duration_since)?; Some(cooldown_left) }) @@ -130,13 +157,14 @@ impl CooldownTracker { pub fn start_cooldown(&mut self, ctx: CooldownContext) { let now = Instant::now(); - self.global_invocation = Some(now); - self.user_invocations.insert(ctx.user_id, now); - self.channel_invocations.insert(ctx.channel_id, now); + self.global_invocation = Some((now, STARTING_BURST_AMOUNT)); + self.user_invocations.insert(ctx.user_id, (now, STARTING_BURST_AMOUNT)); + self.channel_invocations.insert(ctx.channel_id, (now, STARTING_BURST_AMOUNT)); if let Some(guild_id) = ctx.guild_id { - self.guild_invocations.insert(guild_id, now); - self.member_invocations.insert((ctx.user_id, guild_id), now); + self.guild_invocations.insert(guild_id, (now, STARTING_BURST_AMOUNT)); + self.member_invocations + .insert((ctx.user_id, guild_id), (now, STARTING_BURST_AMOUNT)); } } @@ -146,22 +174,136 @@ impl CooldownTracker { /// flexibility in cases where you might want to shorten or lengthen a cooldown after /// invocation. pub fn set_last_invocation(&mut self, cooldown_type: CooldownType, instant: Instant) { + // FIXME: decide whether burst amounts should just be set strictly to 1 or if the old value + // should cascade match cooldown_type { - CooldownType::Global => self.global_invocation = Some(instant), + CooldownType::Global => { + self.global_invocation = Some(( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + )) + } CooldownType::User(user_id) => { - self.user_invocations.insert(user_id, instant); + self.user_invocations.insert( + user_id, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } CooldownType::Guild(guild_id) => { - self.guild_invocations.insert(guild_id, instant); + self.guild_invocations.insert( + guild_id, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } CooldownType::Channel(channel_id) => { - self.channel_invocations.insert(channel_id, instant); + self.channel_invocations.insert( + channel_id, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } CooldownType::Member(member) => { - self.member_invocations.insert(member, instant); + self.member_invocations.insert( + member, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } } } + + /// Increments burst counter by 1, and otherwise resets the cooldown if the initial cooldown + /// window has passed. + pub fn increment_usage(&mut self, ctx: CooldownContext, config: &CooldownConfig) { + let now = Instant::now(); + self.global_invocation = Some( + self.global_invocation + .and_then(|(window, burst_amount)| { + if let Some(global_cooldown) = config.global { + let duration_since = now.saturating_duration_since(window); + if duration_since <= global_cooldown { + return Some((window, burst_amount + 1)); + } + } + None + }) + .unwrap_or((now, STARTING_BURST_AMOUNT)), + ); + self.user_invocations + .entry(ctx.user_id) + .and_modify(|(window, burst_amount)| { + if let Some(user_cooldown) = config.user { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= user_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + self.channel_invocations + .entry(ctx.channel_id) + .and_modify(|(window, burst_amount)| { + if let Some(channel_cooldown) = config.channel { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= channel_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + + if let Some(guild_id) = ctx.guild_id { + self.guild_invocations + .entry(guild_id) + .and_modify(|(window, burst_amount)| { + if let Some(guild_cooldown) = config.guild { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= guild_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + self.member_invocations + .entry((ctx.user_id, guild_id)) + .and_modify(|(window, burst_amount)| { + if let Some(member_cooldown) = config.member { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= member_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + } + } } impl<'a> From<&'a serenity::Message> for CooldownContext { @@ -211,7 +353,7 @@ mod test { ..Default::default() }; let tracker = CooldownTracker { - global_invocation: Some(Instant::now() - Duration::from_secs(1)), + global_invocation: Some((Instant::now() - Duration::from_secs(1), 1)), ..Default::default() }; let ctx = CooldownContext { From 6074287e44ae1df39492a4f99cf2d05e34c0126b Mon Sep 17 00:00:00 2001 From: Jareth Gomes Date: Sat, 4 Oct 2025 02:37:18 +0000 Subject: [PATCH 3/4] chore: Add tests for burst cooldowns --- src/cooldown.rs | 56 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/cooldown.rs b/src/cooldown.rs index 7fc78188bd6..9f1bc80d76e 100644 --- a/src/cooldown.rs +++ b/src/cooldown.rs @@ -703,4 +703,60 @@ mod test { assert!(cooldown < Duration::from_secs(1)); assert!(cooldown > Duration::from_secs(0)); } + + #[test] + fn global_bursts_do_not_return_cooldown_burst_count_is_exceeded() { + const BURST_AMOUNT: u64 = 10; + let config = CooldownConfig { + global: Some(Duration::from_secs(10)), + global_burst_amount: Some(BURST_AMOUNT), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + for _ in 0..BURST_AMOUNT { + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.increment_usage(ctx.clone(), &config); + } + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(10)); + assert!(cooldown > Duration::from_secs(9)); + } + + #[test] + fn member_bursts_do_not_return_cooldown_burst_count_is_exceeded() { + const BURST_AMOUNT: u64 = 10; + let config = CooldownConfig { + member: Some(Duration::from_secs(10)), + member_burst_amount: Some(BURST_AMOUNT), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + for _ in 0..BURST_AMOUNT { + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.increment_usage(ctx.clone(), &config); + } + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(10)); + assert!(cooldown > Duration::from_secs(9)); + } } From 8454622ccd7237b43e29ff1c100450080561bdba Mon Sep 17 00:00:00 2001 From: Jareth Gomes Date: Sat, 4 Oct 2025 03:08:48 +0000 Subject: [PATCH 4/4] chore: Add example command to feature_showcase --- examples/feature_showcase/checks.rs | 14 ++++++++++++++ examples/feature_showcase/main.rs | 1 + 2 files changed, 15 insertions(+) diff --git a/examples/feature_showcase/checks.rs b/examples/feature_showcase/checks.rs index 87adef8be85..a946a5e6a48 100644 --- a/examples/feature_showcase/checks.rs +++ b/examples/feature_showcase/checks.rs @@ -98,6 +98,20 @@ pub async fn cooldowns(ctx: Context<'_>) -> Result<(), Error> { Ok(()) } +// Burstable cooldowns +#[poise::command( + prefix_command, + track_edits, + slash_command, + user_cooldown = 20, + user_cooldown_burst = 3 +)] +pub async fn burstable_cooldown(ctx: Context<'_>) -> Result<(), Error> { + ctx.say("You successfully called this burstable command") + .await?; + Ok(()) +} + #[poise::command(prefix_command, slash_command)] pub async fn minmax( ctx: Context<'_>, diff --git a/examples/feature_showcase/main.rs b/examples/feature_showcase/main.rs index afa78ea93cb..88bc8c9a897 100644 --- a/examples/feature_showcase/main.rs +++ b/examples/feature_showcase/main.rs @@ -45,6 +45,7 @@ async fn main() { checks::delete(), checks::ferrisparty(), checks::cooldowns(), + checks::burstable_cooldown(), checks::minmax(), checks::get_guild_name(), checks::only_in_dms(),