Pārlūkot izejas kodu

token limit support context (#2335)

* token limit support context

* add token limit with ctx

add token limit with ctx

Co-authored-by: sado <liaoyonglin@bilibili.com>
sado 2 gadi atpakaļ
vecāks
revīzija
f068062b13
2 mainītis faili ar 55 papildinājumiem un 3 dzēšanām
  1. 30 3
      core/limit/tokenlimit.go
  2. 25 0
      core/limit/tokenlimit_test.go

+ 30 - 3
core/limit/tokenlimit.go

@@ -1,6 +1,8 @@
 package limit
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"strconv"
 	"sync"
@@ -84,19 +86,38 @@ func (lim *TokenLimiter) Allow() bool {
 	return lim.AllowN(time.Now(), 1)
 }
 
+// AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context.
+func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
+	return lim.AllowNCtx(ctx, time.Now(), 1)
+}
+
 // AllowN reports whether n events may happen at time now.
 // Use this method if you intend to drop / skip events that exceed the rate.
 // Otherwise, use Reserve or Wait.
 func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
-	return lim.reserveN(now, n)
+	return lim.reserveN(context.Background(), now, n)
 }
 
-func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
+// AllowNCtx reports whether n events may happen at time now with incoming context.
+// Use this method if you intend to drop / skip events that exceed the rate.
+// Otherwise, use Reserve or Wait.
+func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
+	return lim.reserveN(ctx, now, n)
+}
+
+func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
+	select {
+	case <-ctx.Done():
+		logx.Errorf("fail to use rate limiter: %s", ctx.Err())
+		return false
+	default:
+	}
+
 	if atomic.LoadUint32(&lim.redisAlive) == 0 {
 		return lim.rescueLimiter.AllowN(now, n)
 	}
 
-	resp, err := lim.store.Eval(
+	resp, err := lim.store.EvalCtx(ctx,
 		script,
 		[]string{
 			lim.tokenKey,
@@ -113,6 +134,12 @@ func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
 	if err == redis.Nil {
 		return false
 	}
+
+	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
+		logx.Errorf("fail to use rate limiter: %s", err)
+		return false
+	}
+
 	if err != nil {
 		logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
 		lim.startMonitor()

+ 25 - 0
core/limit/tokenlimit_test.go

@@ -1,6 +1,7 @@
 package limit
 
 import (
+	"context"
 	"testing"
 	"time"
 
@@ -15,6 +16,30 @@ func init() {
 	logx.Disable()
 }
 
+func TestTokenLimit_WithCtx(t *testing.T) {
+	s, err := miniredis.Run()
+	assert.Nil(t, err)
+
+	const (
+		total = 100
+		rate  = 5
+		burst = 10
+	)
+	l := NewTokenLimiter(rate, burst, redis.New(s.Addr()), "tokenlimit")
+	defer s.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	ok := l.AllowCtx(ctx)
+	assert.True(t, ok)
+
+	cancel()
+	for i := 0; i < total; i++ {
+		ok := l.AllowCtx(ctx)
+		assert.False(t, ok)
+		assert.False(t, l.monitorStarted)
+	}
+}
+
 func TestTokenLimit_Rescue(t *testing.T) {
 	s, err := miniredis.Run()
 	assert.Nil(t, err)