diff --git a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs index 26a65c5..59a7c1e 100644 --- a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs +++ b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs @@ -76,7 +76,7 @@ public RedisTokenBucketManager( RateLimitTimestampKey = new RedisKey($"rl:tb:{{{partitionKey}}}:ts"); } - internal async Task TryAcquireLeaseAsync() + internal async Task TryAcquireLeaseAsync(int permitCount) { var database = _connectionMultiplexer.GetDatabase(); @@ -89,7 +89,7 @@ internal async Task TryAcquireLeaseAsync() tokens_per_period = (RedisValue)_options.TokensPerPeriod, token_limit = (RedisValue)_options.TokenLimit, replenish_period = (RedisValue)_options.ReplenishmentPeriod.TotalMilliseconds, - permit_count = (RedisValue)1D, + permit_count = (RedisValue)permitCount, current_time = (RedisValue)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), }); diff --git a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs index 875123f..9c7d02f 100644 --- a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs +++ b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs @@ -62,7 +62,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.TokenLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -71,14 +71,14 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return FailedLease; } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new TokenBucketLeaseContext { Limit = _options.TokenLimit, }; - var response = await _redisManager.TryAcquireLeaseAsync(); + var response = await _redisManager.TryAcquireLeaseAsync(permitCount); leaseContext.Allowed = response.Allowed; leaseContext.Count = response.Count; diff --git a/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs index acd9553..c7de18a 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs @@ -100,5 +100,28 @@ public async Task CanAcquireAsyncResource() using var lease2 = await limiter.AcquireAsync(); Assert.False(lease2.IsAcquired); } + + [Fact] + public async Task CanAcquireMultiPermits() + { + using var limiter = new RedisTokenBucketRateLimiter( + partitionKey: Guid.NewGuid().ToString(), + new RedisTokenBucketRateLimiterOptions + { + TokenLimit = 5, + TokensPerPeriod = 5, + ReplenishmentPeriod = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(4); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(3); + Assert.False(lease2.IsAcquired); + + using var lease3 = await limiter.AcquireAsync(1); + Assert.True(lease3.IsAcquired); + } } }