Skip to content

Commit 1e4c1e2

Browse files
committed
Add AllowAtMostN
1 parent 46c7bc5 commit 1e4c1e2

File tree

4 files changed

+196
-29
lines changed

4 files changed

+196
-29
lines changed

example_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ func ExampleNewLimiter() {
2020
if err != nil {
2121
panic(err)
2222
}
23-
fmt.Println(res.Allowed, res.Remaining)
24-
// Output: true 9
23+
fmt.Println("allowed", res.Allowed, "remaining", res.Remaining)
24+
// Output: allowed 1 remaining 9
2525
}

lua.go

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,19 @@ import "github.com/go-redis/redis/v8"
44

55
// Copyright (c) 2017 Pavel Pravosud
66
// https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
7-
var gcra = redis.NewScript(`
7+
var allowN = redis.NewScript(`
88
-- this script has side-effects, so it requires replicate commands mode
99
redis.replicate_commands()
1010
1111
local rate_limit_key = KEYS[1]
1212
local burst = ARGV[1]
1313
local rate = ARGV[2]
1414
local period = ARGV[3]
15-
local cost = ARGV[4]
15+
local cost = tonumber(ARGV[4])
1616
1717
local emission_interval = period / rate
1818
local increment = emission_interval * cost
1919
local burst_offset = emission_interval * burst
20-
local now = redis.call("TIME")
2120
2221
-- redis returns time as an array containing two integers: seconds of the epoch
2322
-- time (10 digits) and microseconds (6 digits). for convenience we need to
@@ -27,6 +26,7 @@ local now = redis.call("TIME")
2726
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
2827
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
2928
local jan_1_2017 = 1483228800
29+
local now = redis.call("TIME")
3030
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
3131
3232
local tat = redis.call("GET", rate_limit_key)
@@ -37,28 +37,86 @@ else
3737
tat = tonumber(tat)
3838
end
3939
40-
local new_tat = math.max(tat, now) + increment
40+
tat = math.max(tat, now)
4141
42+
local new_tat = tat + increment
4243
local allow_at = new_tat - burst_offset
44+
4345
local diff = now - allow_at
46+
local remaining = math.floor(diff / emission_interval + 0.5)
47+
48+
if remaining >= 0 then
49+
local reset_after = new_tat - now
50+
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
51+
local retry_after = -1
52+
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
53+
end
54+
55+
remaining = 0
56+
local reset_after = tat - now
57+
local retry_after = diff * -1
58+
return {0, remaining, tostring(retry_after), tostring(reset_after)}
59+
`)
60+
61+
var allowAtMost = redis.NewScript(`
62+
-- this script has side-effects, so it requires replicate commands mode
63+
redis.replicate_commands()
64+
65+
local rate_limit_key = KEYS[1]
66+
local burst = ARGV[1]
67+
local rate = ARGV[2]
68+
local period = ARGV[3]
69+
local cost = tonumber(ARGV[4])
70+
71+
local emission_interval = period / rate
72+
local burst_offset = emission_interval * burst
73+
74+
-- redis returns time as an array containing two integers: seconds of the epoch
75+
-- time (10 digits) and microseconds (6 digits). for convenience we need to
76+
-- convert them to a floating point number. the resulting number is 16 digits,
77+
-- bordering on the limits of a 64-bit double-precision floating point number.
78+
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
79+
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
80+
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
81+
local jan_1_2017 = 1483228800
82+
local now = redis.call("TIME")
83+
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
84+
85+
local tat = redis.call("GET", rate_limit_key)
86+
87+
if not tat then
88+
tat = now
89+
else
90+
tat = tonumber(tat)
91+
end
4492
45-
local limited
46-
local retry_after
47-
local reset_after
93+
tat = math.max(tat, now)
4894
49-
local remaining = math.floor(diff / emission_interval + 0.5) -- poor man's round
95+
local diff = now - (tat - burst_offset)
96+
local remaining = math.floor(diff / emission_interval + 0.5)
5097
51-
if remaining < 0 then
52-
limited = 1
98+
if remaining == 0 then
99+
return {
100+
0, -- allowed
101+
0, -- remaining
102+
}
103+
end
104+
105+
if remaining < cost then
106+
cost = remaining
53107
remaining = 0
54-
reset_after = tat - now
55-
retry_after = diff * -1
56108
else
57-
limited = 0
58-
reset_after = new_tat - now
59-
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
60-
retry_after = -1
109+
remaining = remaining - cost
61110
end
62111
63-
return {limited, remaining, tostring(retry_after), tostring(reset_after)}
112+
local increment = emission_interval * cost
113+
local new_tat = tat + increment
114+
115+
local reset_after = new_tat - now
116+
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
117+
118+
return {
119+
cost,
120+
remaining,
121+
}
64122
`)

rate.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,20 @@ func NewLimiter(rdb rediser) *Limiter {
6161
}
6262
}
6363

64-
// Allow is shorthand for AllowN(key, 1).
64+
// Allow is a shortcut for AllowN(ctx, key, limit, 1).
6565
func (l *Limiter) Allow(ctx context.Context, key string, limit *Limit) (*Result, error) {
6666
return l.AllowN(ctx, key, limit, 1)
6767
}
6868

6969
// AllowN reports whether n events may happen at time now.
70-
func (l *Limiter) AllowN(ctx context.Context, key string, limit *Limit, n int) (*Result, error) {
70+
func (l *Limiter) AllowN(
71+
ctx context.Context,
72+
key string,
73+
limit *Limit,
74+
n int,
75+
) (*Result, error) {
7176
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
72-
v, err := gcra.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
77+
v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
7378
if err != nil {
7479
return nil, err
7580
}
@@ -88,14 +93,38 @@ func (l *Limiter) AllowN(ctx context.Context, key string, limit *Limit, n int) (
8893

8994
res := &Result{
9095
Limit: limit,
91-
Allowed: values[0].(int64) == 0,
96+
Allowed: int(values[0].(int64)),
9297
Remaining: int(values[1].(int64)),
9398
RetryAfter: dur(retryAfter),
9499
ResetAfter: dur(resetAfter),
95100
}
96101
return res, nil
97102
}
98103

104+
// AllowAtMostN reports whether at most n events may happen at time now.
105+
// It returns number of allowed events. RetryAfter and ResetAfter are not set.
106+
func (l *Limiter) AllowAtMostN(
107+
ctx context.Context,
108+
key string,
109+
limit *Limit,
110+
n int,
111+
) (*Result, error) {
112+
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
113+
v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
114+
if err != nil {
115+
return nil, err
116+
}
117+
118+
values = v.([]interface{})
119+
120+
res := &Result{
121+
Limit: limit,
122+
Allowed: int(values[0].(int64)),
123+
Remaining: int(values[1].(int64)),
124+
}
125+
return res, nil
126+
}
127+
99128
func dur(f float64) time.Duration {
100129
if f == -1 {
101130
return -1
@@ -108,7 +137,7 @@ type Result struct {
108137
Limit *Limit
109138

110139
// Allowed reports whether event may happen at time now.
111-
Allowed bool
140+
Allowed int
112141

113142
// Remaining is the maximum number of requests that could be
114143
// permitted instantaneously for this key given the current

rate_test.go

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,39 +29,119 @@ func TestAllow(t *testing.T) {
2929

3030
res, err := l.Allow(ctx, "test_id", limit)
3131
assert.Nil(t, err)
32-
assert.True(t, res.Allowed)
32+
assert.Equal(t, res.Allowed, 1)
3333
assert.Equal(t, res.Remaining, 9)
3434
assert.Equal(t, res.RetryAfter, time.Duration(-1))
3535
assert.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond))
3636

3737
res, err = l.AllowN(ctx, "test_id", limit, 2)
3838
assert.Nil(t, err)
39-
assert.True(t, res.Allowed)
39+
assert.Equal(t, res.Allowed, 2)
4040
assert.Equal(t, res.Remaining, 7)
4141
assert.Equal(t, res.RetryAfter, time.Duration(-1))
4242
assert.InDelta(t, res.ResetAfter, 300*time.Millisecond, float64(10*time.Millisecond))
4343

44+
res, err = l.AllowN(ctx, "test_id", limit, 7)
45+
assert.Nil(t, err)
46+
assert.Equal(t, res.Allowed, 7)
47+
assert.Equal(t, res.Remaining, 0)
48+
assert.Equal(t, res.RetryAfter, time.Duration(-1))
49+
assert.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
50+
4451
res, err = l.AllowN(ctx, "test_id", limit, 1000)
4552
assert.Nil(t, err)
46-
assert.False(t, res.Allowed)
53+
assert.Equal(t, res.Allowed, 0)
4754
assert.Equal(t, res.Remaining, 0)
4855
assert.InDelta(t, res.RetryAfter, 99*time.Second, float64(time.Second))
56+
assert.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
57+
}
58+
59+
func TestAllowAtMostN(t *testing.T) {
60+
ctx := context.Background()
61+
62+
l := rateLimiter()
63+
limit := redis_rate.PerSecond(10)
64+
65+
res, err := l.Allow(ctx, "test_id", limit)
66+
assert.Nil(t, err)
67+
assert.Equal(t, res.Allowed, 1)
68+
assert.Equal(t, res.Remaining, 9)
69+
assert.Equal(t, res.RetryAfter, time.Duration(-1))
70+
assert.InDelta(t, res.ResetAfter, 100*time.Millisecond, float64(10*time.Millisecond))
71+
72+
res, err = l.AllowAtMostN(ctx, "test_id", limit, 2)
73+
assert.Nil(t, err)
74+
assert.Equal(t, res.Allowed, 2)
75+
assert.Equal(t, res.Remaining, 7)
76+
77+
res, err = l.AllowN(ctx, "test_id", limit, 0)
78+
assert.Nil(t, err)
79+
assert.Equal(t, res.Allowed, 0)
80+
assert.Equal(t, res.Remaining, 7)
81+
assert.Equal(t, res.RetryAfter, time.Duration(-1))
4982
assert.InDelta(t, res.ResetAfter, 300*time.Millisecond, float64(10*time.Millisecond))
83+
84+
res, err = l.AllowAtMostN(ctx, "test_id", limit, 10)
85+
assert.Nil(t, err)
86+
assert.Equal(t, res.Allowed, 7)
87+
assert.Equal(t, res.Remaining, 0)
88+
89+
res, err = l.AllowN(ctx, "test_id", limit, 0)
90+
assert.Nil(t, err)
91+
assert.Equal(t, res.Allowed, 0)
92+
assert.Equal(t, res.Remaining, 0)
93+
assert.Equal(t, res.RetryAfter, time.Duration(-1))
94+
assert.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
95+
96+
res, err = l.AllowAtMostN(ctx, "test_id", limit, 1000)
97+
assert.Nil(t, err)
98+
assert.Equal(t, res.Allowed, 0)
99+
assert.Equal(t, res.Remaining, 0)
100+
101+
res, err = l.AllowN(ctx, "test_id", limit, 1000)
102+
assert.Nil(t, err)
103+
assert.Equal(t, res.Allowed, 0)
104+
assert.Equal(t, res.Remaining, 0)
105+
assert.InDelta(t, res.RetryAfter, 99*time.Second, float64(time.Second))
106+
assert.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
50107
}
51108

52109
func BenchmarkAllow(b *testing.B) {
53110
ctx := context.Background()
54111
l := rateLimiter()
55-
limit := redis_rate.PerSecond(10000)
112+
limit := redis_rate.PerSecond(1e6)
56113

57114
b.ResetTimer()
58115

59116
b.RunParallel(func(pb *testing.PB) {
60117
for pb.Next() {
61-
_, err := l.Allow(ctx, "foo", limit)
118+
res, err := l.Allow(ctx, "foo", limit)
62119
if err != nil {
63120
b.Fatal(err)
64121
}
122+
if res.Allowed == 0 {
123+
panic("not reached")
124+
}
125+
}
126+
})
127+
}
128+
129+
func BenchmarkAllowAtMostN(b *testing.B) {
130+
ctx := context.Background()
131+
l := rateLimiter()
132+
limit := redis_rate.PerSecond(1e6)
133+
134+
b.ResetTimer()
135+
136+
b.RunParallel(func(pb *testing.PB) {
137+
for pb.Next() {
138+
res, err := l.AllowAtMostN(ctx, "foo", limit, 1)
139+
if err != nil {
140+
b.Fatal(err)
141+
}
142+
if res.Allowed == 0 {
143+
panic("not reached")
144+
}
65145
}
66146
})
67147
}

0 commit comments

Comments
 (0)