Skip to content

Commit c52e4fd

Browse files
authored
enable allowN and allowAtMost to peek at contents (#47)
* add failing test for allowN increment of zero * fix lua script to allow increment by zero (effectively peek) In past versions of redis_rate, you could peek at the current state in the redis limiter by calling allowN but not incrementing. The move to lua script did not preserve this (unadvertised) ability. This adds it back, the only issue was trying to SET a redis row with an expiry of zero if the row didn't exist. Since the caller doesn't want to increment at all, this changes the SET to not be called if the value is absent. * same test and fix for allowAtMost increment by zero * add test for peeking after a write
1 parent 874a15a commit c52e4fd

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

lua.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ if remaining < 0 then
5757
end
5858
5959
local reset_after = new_tat - now
60-
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
60+
if reset_after > 0 then
61+
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
62+
end
6163
local retry_after = -1
6264
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
6365
`)
@@ -121,7 +123,9 @@ local increment = emission_interval * cost
121123
local new_tat = tat + increment
122124
123125
local reset_after = new_tat - now
124-
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
126+
if reset_after > 0 then
127+
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
128+
end
125129
126130
return {
127131
cost,

rate_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,36 @@ func TestAllow(t *testing.T) {
6868
require.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
6969
}
7070

71+
func TestAllowN_IncrementZero(t *testing.T) {
72+
ctx := context.Background()
73+
l := rateLimiter()
74+
limit := redis_rate.PerSecond(10)
75+
76+
// Check for a row that's not there
77+
res, err := l.AllowN(ctx, "test_id", limit, 0)
78+
require.Nil(t, err)
79+
require.Equal(t, res.Allowed, 0)
80+
require.Equal(t, res.Remaining, 10)
81+
require.Equal(t, res.RetryAfter, time.Duration(-1))
82+
require.Equal(t, res.ResetAfter, time.Duration(0))
83+
84+
// Now increment it
85+
res, err = l.Allow(ctx, "test_id", limit)
86+
require.Nil(t, err)
87+
require.Equal(t, res.Allowed, 1)
88+
require.Equal(t, res.Remaining, 9)
89+
require.Equal(t, res.RetryAfter, time.Duration(-1))
90+
require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond))
91+
92+
// Peek again
93+
res, err = l.AllowN(ctx, "test_id", limit, 0)
94+
require.Nil(t, err)
95+
require.Equal(t, res.Allowed, 0)
96+
require.Equal(t, res.Remaining, 9)
97+
require.Equal(t, res.RetryAfter, time.Duration(-1))
98+
require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond))
99+
}
100+
71101
func TestRetryAfter(t *testing.T) {
72102
limit := redis_rate.Limit{
73103
Rate: 1,
@@ -146,6 +176,36 @@ func TestAllowAtMost(t *testing.T) {
146176
require.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
147177
}
148178

179+
func TestAllowAtMost_IncrementZero(t *testing.T) {
180+
ctx := context.Background()
181+
l := rateLimiter()
182+
limit := redis_rate.PerSecond(10)
183+
184+
// Check for a row that isn't there
185+
res, err := l.AllowAtMost(ctx, "test_id", limit, 0)
186+
require.Nil(t, err)
187+
require.Equal(t, res.Allowed, 0)
188+
require.Equal(t, res.Remaining, 10)
189+
require.Equal(t, res.RetryAfter, time.Duration(-1))
190+
require.Equal(t, res.ResetAfter, time.Duration(0))
191+
192+
// Now increment it
193+
res, err = l.Allow(ctx, "test_id", limit)
194+
require.Nil(t, err)
195+
require.Equal(t, res.Allowed, 1)
196+
require.Equal(t, res.Remaining, 9)
197+
require.Equal(t, res.RetryAfter, time.Duration(-1))
198+
require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond))
199+
200+
// Peek again
201+
res, err = l.AllowAtMost(ctx, "test_id", limit, 0)
202+
require.Nil(t, err)
203+
require.Equal(t, res.Allowed, 0)
204+
require.Equal(t, res.Remaining, 9)
205+
require.Equal(t, res.RetryAfter, time.Duration(-1))
206+
require.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond))
207+
}
208+
149209
func BenchmarkAllow(b *testing.B) {
150210
ctx := context.Background()
151211
l := rateLimiter()

0 commit comments

Comments
 (0)