diff --git a/examples/feature_showcase/checks.rs b/examples/feature_showcase/checks.rs index 87adef8be85..a946a5e6a48 100644 --- a/examples/feature_showcase/checks.rs +++ b/examples/feature_showcase/checks.rs @@ -98,6 +98,20 @@ pub async fn cooldowns(ctx: Context<'_>) -> Result<(), Error> { Ok(()) } +// Burstable cooldowns +#[poise::command( + prefix_command, + track_edits, + slash_command, + user_cooldown = 20, + user_cooldown_burst = 3 +)] +pub async fn burstable_cooldown(ctx: Context<'_>) -> Result<(), Error> { + ctx.say("You successfully called this burstable command") + .await?; + Ok(()) +} + #[poise::command(prefix_command, slash_command)] pub async fn minmax( ctx: Context<'_>, diff --git a/examples/feature_showcase/main.rs b/examples/feature_showcase/main.rs index afa78ea93cb..88bc8c9a897 100644 --- a/examples/feature_showcase/main.rs +++ b/examples/feature_showcase/main.rs @@ -45,6 +45,7 @@ async fn main() { checks::delete(), checks::ferrisparty(), checks::cooldowns(), + checks::burstable_cooldown(), checks::minmax(), checks::get_guild_name(), checks::only_in_dms(), diff --git a/macros/src/command/mod.rs b/macros/src/command/mod.rs index 3c9d82feb1f..0e4e1f6b76b 100644 --- a/macros/src/command/mod.rs +++ b/macros/src/command/mod.rs @@ -59,6 +59,12 @@ pub struct CommandArgs { guild_cooldown: Option, channel_cooldown: Option, member_cooldown: Option, + + global_cooldown_burst: Option, + user_cooldown_burst: Option, + guild_cooldown_burst: Option, + channel_cooldown_burst: Option, + member_cooldown_burst: Option, } /// Representation of the function parameter attribute arguments @@ -429,18 +435,28 @@ fn generate_cooldown_config(args: &CommandArgs) -> proc_macro2::TokenStream { let to_seconds_path = quote::quote!(std::time::Duration::from_secs); let global_cooldown = wrap_option_and_map(args.global_cooldown, &to_seconds_path); + let global_burst = wrap_option(args.global_cooldown_burst); let user_cooldown = wrap_option_and_map(args.user_cooldown, &to_seconds_path); + let user_burst = wrap_option(args.user_cooldown_burst); let guild_cooldown = wrap_option_and_map(args.guild_cooldown, &to_seconds_path); + let guild_burst = wrap_option(args.guild_cooldown_burst); let channel_cooldown = wrap_option_and_map(args.channel_cooldown, &to_seconds_path); + let channel_burst = wrap_option(args.channel_cooldown_burst); let member_cooldown = wrap_option_and_map(args.member_cooldown, &to_seconds_path); + let member_burst = wrap_option(args.member_cooldown_burst); quote::quote!( std::sync::RwLock::new(::poise::CooldownConfig { global: #global_cooldown, + global_burst_amount: #global_burst, user: #user_cooldown, + user_burst_amount: #user_burst, guild: #guild_cooldown, + guild_burst_amount: #guild_burst, channel: #channel_cooldown, + channel_burst_amount: #channel_burst, member: #member_cooldown, + member_burst_amount: #member_burst, __non_exhaustive: () }) ) diff --git a/macros/src/command/prefix.rs b/macros/src/command/prefix.rs index d4750ae1216..bdb84961dfb 100644 --- a/macros/src/command/prefix.rs +++ b/macros/src/command/prefix.rs @@ -70,7 +70,8 @@ pub fn generate_prefix_action(inv: &Invocation) -> Result as core::ops::Deref>::deref(&cooldown_config)); } inner(ctx.into(), #( #param_idents, )* ) diff --git a/macros/src/command/slash.rs b/macros/src/command/slash.rs index bb316e8c9a1..81bf1d4e98c 100644 --- a/macros/src/command/slash.rs +++ b/macros/src/command/slash.rs @@ -190,7 +190,8 @@ pub fn generate_slash_action(inv: &Invocation) -> Result as core::ops::Deref>::deref(&cooldown_config)); } inner(ctx.into(), #( #param_identifiers, )*) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 802be75e9b2..e319593b408 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -69,10 +69,15 @@ for example for command-specific help (i.e. `~help command_name`). Escape newlin ## Cooldown - `manual_cooldowns`: Allows overriding the framework's built-in cooldowns tracking without affecting other commands. - `global_cooldown`: Minimum duration in seconds between invocations, globally +- `global_cooldown_burst`: The number of times the command can be invoked within the cooldown period, globally. The default is 1. - `user_cooldown`: Minimum duration in seconds between invocations, per user +- `user_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per user. The default is 1. - `guild_cooldown`: Minimum duration in seconds between invocations, per guild +- `guild_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per guild. The default is 1. - `channel_cooldown`: Minimum duration in seconds between invocations, per channel +- `channel_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per channel. The default is 1. - `member_cooldown`: Minimum duration in seconds between invocations, per guild member +- `member_cooldown_burst`: The number of times the command can be invoked within the cooldown period, per guild member. The default is 1. ## Other diff --git a/src/cooldown.rs b/src/cooldown.rs index 8afba194de8..9f1bc80d76e 100644 --- a/src/cooldown.rs +++ b/src/cooldown.rs @@ -5,6 +5,11 @@ use crate::serenity_prelude as serenity; use std::collections::HashMap; use std::time::{Duration, Instant}; +/// The starting value that gets assigned when inserted into [`CooldownTracker`] +const STARTING_BURST_AMOUNT: u64 = 1; +/// The default burst amount if it is not provided in [`CooldownConfig`] +const DEFAULT_BURST_AMOUNT: u64 = 1; + /// Subset of [`crate::Context`] so that [`Cooldowns`] can be used without requiring a full [Context](`crate::Context`) /// (ie from within an `event_handler`) #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] @@ -22,14 +27,24 @@ pub struct CooldownContext { pub struct CooldownConfig { /// This cooldown operates on a global basis pub global: Option, + /// This is how many operations can be invoked within the cooldown period on a global basis + pub global_burst_amount: Option, /// This cooldown operates on a per-user basis pub user: Option, + /// This is how many operations can be invoked within the cooldown period on a per-user basis + pub user_burst_amount: Option, /// This cooldown operates on a per-guild basis pub guild: Option, + /// This is how many operations can be invoked within the cooldown period on a per-guild basis + pub guild_burst_amount: Option, /// This cooldown operates on a per-channel basis pub channel: Option, + /// This is how many operations can be invoked within the cooldown period on a per-channel basis + pub channel_burst_amount: Option, /// This cooldown operates on a per-member basis pub member: Option, + /// This is how many operations can be invoked within the cooldown period on a per-member basis + pub member_burst_amount: Option, #[doc(hidden)] pub __non_exhaustive: (), } @@ -41,15 +56,15 @@ pub struct CooldownConfig { #[derive(Default, Clone, Debug, PartialEq, Eq)] pub struct CooldownTracker { /// Stores the timestamp of the last global invocation - global_invocation: Option, + global_invocation: Option<(Instant, u64)>, /// Stores the timestamps of the last invocation per user - user_invocations: HashMap, + user_invocations: HashMap, /// Stores the timestamps of the last invocation per guild - guild_invocations: HashMap, + guild_invocations: HashMap, /// Stores the timestamps of the last invocation per channel - channel_invocations: HashMap, + channel_invocations: HashMap, /// Stores the timestamps of the last invocation per member (user and guild) - member_invocations: HashMap<(serenity::UserId, serenity::GuildId), Instant>, + member_invocations: HashMap<(serenity::UserId, serenity::GuildId), (Instant, u64)>, } /// Possible types of command cooldowns. @@ -92,13 +107,19 @@ impl CooldownTracker { cooldown_durations: &CooldownConfig, ) -> Option { let mut cooldown_data = vec![ - (cooldown_durations.global, self.global_invocation), + ( + cooldown_durations.global, + cooldown_durations.global_burst_amount, + self.global_invocation, + ), ( cooldown_durations.user, + cooldown_durations.user_burst_amount, self.user_invocations.get(&ctx.user_id).copied(), ), ( cooldown_durations.channel, + cooldown_durations.channel_burst_amount, self.channel_invocations.get(&ctx.channel_id).copied(), ), ]; @@ -106,10 +127,12 @@ impl CooldownTracker { if let Some(guild_id) = ctx.guild_id { cooldown_data.push(( cooldown_durations.guild, + cooldown_durations.guild_burst_amount, self.guild_invocations.get(&guild_id).copied(), )); cooldown_data.push(( cooldown_durations.member, + cooldown_durations.member_burst_amount, self.member_invocations .get(&(ctx.user_id, guild_id)) .copied(), @@ -118,8 +141,12 @@ impl CooldownTracker { cooldown_data .iter() - .filter_map(|&(cooldown, last_invocation)| { - let duration_since = Instant::now().saturating_duration_since(last_invocation?); + .filter_map(|&(cooldown, burst_amount, last_invocation)| { + let last_invocation = last_invocation?; + if burst_amount.unwrap_or(DEFAULT_BURST_AMOUNT) > last_invocation.1 { + return None; + } + let duration_since = Instant::now().saturating_duration_since(last_invocation.0); let cooldown_left = cooldown?.checked_sub(duration_since)?; Some(cooldown_left) }) @@ -130,13 +157,14 @@ impl CooldownTracker { pub fn start_cooldown(&mut self, ctx: CooldownContext) { let now = Instant::now(); - self.global_invocation = Some(now); - self.user_invocations.insert(ctx.user_id, now); - self.channel_invocations.insert(ctx.channel_id, now); + self.global_invocation = Some((now, STARTING_BURST_AMOUNT)); + self.user_invocations.insert(ctx.user_id, (now, STARTING_BURST_AMOUNT)); + self.channel_invocations.insert(ctx.channel_id, (now, STARTING_BURST_AMOUNT)); if let Some(guild_id) = ctx.guild_id { - self.guild_invocations.insert(guild_id, now); - self.member_invocations.insert((ctx.user_id, guild_id), now); + self.guild_invocations.insert(guild_id, (now, STARTING_BURST_AMOUNT)); + self.member_invocations + .insert((ctx.user_id, guild_id), (now, STARTING_BURST_AMOUNT)); } } @@ -146,22 +174,136 @@ impl CooldownTracker { /// flexibility in cases where you might want to shorten or lengthen a cooldown after /// invocation. pub fn set_last_invocation(&mut self, cooldown_type: CooldownType, instant: Instant) { + // FIXME: decide whether burst amounts should just be set strictly to 1 or if the old value + // should cascade match cooldown_type { - CooldownType::Global => self.global_invocation = Some(instant), + CooldownType::Global => { + self.global_invocation = Some(( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + )) + } CooldownType::User(user_id) => { - self.user_invocations.insert(user_id, instant); + self.user_invocations.insert( + user_id, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } CooldownType::Guild(guild_id) => { - self.guild_invocations.insert(guild_id, instant); + self.guild_invocations.insert( + guild_id, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } CooldownType::Channel(channel_id) => { - self.channel_invocations.insert(channel_id, instant); + self.channel_invocations.insert( + channel_id, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } CooldownType::Member(member) => { - self.member_invocations.insert(member, instant); + self.member_invocations.insert( + member, + ( + instant, + self.global_invocation + .map_or(STARTING_BURST_AMOUNT, |(_, burst_amount)| burst_amount), + ), + ); } } } + + /// Increments burst counter by 1, and otherwise resets the cooldown if the initial cooldown + /// window has passed. + pub fn increment_usage(&mut self, ctx: CooldownContext, config: &CooldownConfig) { + let now = Instant::now(); + self.global_invocation = Some( + self.global_invocation + .and_then(|(window, burst_amount)| { + if let Some(global_cooldown) = config.global { + let duration_since = now.saturating_duration_since(window); + if duration_since <= global_cooldown { + return Some((window, burst_amount + 1)); + } + } + None + }) + .unwrap_or((now, STARTING_BURST_AMOUNT)), + ); + self.user_invocations + .entry(ctx.user_id) + .and_modify(|(window, burst_amount)| { + if let Some(user_cooldown) = config.user { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= user_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + self.channel_invocations + .entry(ctx.channel_id) + .and_modify(|(window, burst_amount)| { + if let Some(channel_cooldown) = config.channel { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= channel_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + + if let Some(guild_id) = ctx.guild_id { + self.guild_invocations + .entry(guild_id) + .and_modify(|(window, burst_amount)| { + if let Some(guild_cooldown) = config.guild { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= guild_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + self.member_invocations + .entry((ctx.user_id, guild_id)) + .and_modify(|(window, burst_amount)| { + if let Some(member_cooldown) = config.member { + let duration_since = now.saturating_duration_since(*window); + if duration_since <= member_cooldown { + *burst_amount += 1; + return; + } + } + *window = now; + *burst_amount = STARTING_BURST_AMOUNT; + }) + .or_insert((now, STARTING_BURST_AMOUNT)); + } + } } impl<'a> From<&'a serenity::Message> for CooldownContext { @@ -173,3 +315,448 @@ impl<'a> From<&'a serenity::Message> for CooldownContext { } } } + +#[cfg(test)] +mod test { + use ::serenity::all::{ChannelId, GuildId, UserId}; + + use super::*; + + #[test] + fn start_cooldown_triggers_guild_cooldown() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } + + #[test] + fn cooldown_resets_after_window_expires() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let tracker = CooldownTracker { + global_invocation: Some((Instant::now() - Duration::from_secs(1), 1)), + ..Default::default() + }; + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + } + + #[tokio::test] + async fn basic_global_cooldown() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn global_cooldown_affects_other_users() { + let config = CooldownConfig { + global: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + let cooldown = tracker.remaining_cooldown(other_user_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } + + #[tokio::test] + async fn basic_user_cooldown() { + let config = CooldownConfig { + user: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn user_cooldown_does_not_affect_other_users() { + let config = CooldownConfig { + user: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker + .remaining_cooldown(other_user_ctx, &config) + .is_none()); + + let same_user_different_channel_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(9876), + }; + + let cooldown = tracker.remaining_cooldown(same_user_different_channel_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } + + #[tokio::test] + async fn basic_channel_cooldown() { + let config = CooldownConfig { + channel: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()) + } + + #[test] + fn channel_cooldown_affects_one_channel() { + let config = CooldownConfig { + channel: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + let cooldown = tracker.remaining_cooldown(other_user_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let same_user_different_channel_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(9876), + }; + + assert!(tracker + .remaining_cooldown(same_user_different_channel_ctx, &config) + .is_none()); + } + + #[tokio::test] + async fn basic_guild_cooldown() { + let config = CooldownConfig { + guild: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn guild_cooldown_affects_one_guild() { + let config = CooldownConfig { + guild: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + let cooldown = tracker.remaining_cooldown(other_user_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + // This is not realistic since while guild id is different, the channel id is the same, but + // this is to demostrate the guild id affects the cooldown. + let same_user_different_guild_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(420)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker + .remaining_cooldown(same_user_different_guild_ctx, &config) + .is_none()); + } + + #[tokio::test] + async fn basic_member_cooldown() { + let config = CooldownConfig { + member: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx.clone(), &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + tokio::time::sleep(cooldown).await; + assert!(tracker.remaining_cooldown(ctx, &config).is_none()); + } + + #[test] + fn member_cooldown_affects_one_member() { + let config = CooldownConfig { + member: Some(Duration::from_secs(1)), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + tracker.start_cooldown(ctx.clone()); + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + + let other_user_ctx = CooldownContext { + user_id: UserId::from(54321), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + assert!(tracker + .remaining_cooldown(other_user_ctx, &config) + .is_none()); + + let same_user_different_channel_ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(9876), + }; + let cooldown = tracker.remaining_cooldown(same_user_different_channel_ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(1)); + assert!(cooldown > Duration::from_secs(0)); + } + + #[test] + fn global_bursts_do_not_return_cooldown_burst_count_is_exceeded() { + const BURST_AMOUNT: u64 = 10; + let config = CooldownConfig { + global: Some(Duration::from_secs(10)), + global_burst_amount: Some(BURST_AMOUNT), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: None, + channel_id: ChannelId::from(67890), + }; + + for _ in 0..BURST_AMOUNT { + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.increment_usage(ctx.clone(), &config); + } + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(10)); + assert!(cooldown > Duration::from_secs(9)); + } + + #[test] + fn member_bursts_do_not_return_cooldown_burst_count_is_exceeded() { + const BURST_AMOUNT: u64 = 10; + let config = CooldownConfig { + member: Some(Duration::from_secs(10)), + member_burst_amount: Some(BURST_AMOUNT), + ..Default::default() + }; + let mut tracker = CooldownTracker::default(); + let ctx = CooldownContext { + user_id: UserId::from(12345), + guild_id: Some(GuildId::from(1337)), + channel_id: ChannelId::from(67890), + }; + + for _ in 0..BURST_AMOUNT { + assert!(tracker.remaining_cooldown(ctx.clone(), &config).is_none()); + + tracker.increment_usage(ctx.clone(), &config); + } + + let cooldown = tracker.remaining_cooldown(ctx, &config); + assert!(cooldown.is_some()); + let cooldown = cooldown.unwrap(); + assert!(cooldown < Duration::from_secs(10)); + assert!(cooldown > Duration::from_secs(9)); + } +}