tokenlimit.go

package limit

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/wuntsong-org/go-zero-plus/core/logx"
	"github.com/wuntsong-org/go-zero-plus/core/stores/redis"
	xrate "golang.org/x/time/rate"
)

const (
	tokenFormat     = "{%s}.tokens"
	timestampFormat = "{%s}.ts"
	pingInterval    = time.Millisecond * 100
)

// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
// KEYS[1] as tokens_key
// KEYS[2] as timestamp_key
var script = redis.NewScript(`local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)
local last_tokens = tonumber(redis.call("get", KEYS[1]))
if last_tokens == nil then
    last_tokens = capacity
end

local last_refreshed = tonumber(redis.call("get", KEYS[2]))
if last_refreshed == nil then
    last_refreshed = 0
end

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
    new_tokens = filled_tokens - requested
end

redis.call("setex", KEYS[1], ttl, new_tokens)
redis.call("setex", KEYS[2], ttl, now)
return allowed`)
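
// Illustrative worked example of the refill math in the script above (the
// numbers are assumptions chosen for clarity, not taken from this file): with
// rate = 100 tokens/s, capacity = 200 and a 3-second gap since last_refreshed,
// delta*rate = 300, so filled_tokens caps at 200; a request of 1 token is
// allowed and 199 tokens are written back. Both keys get
// ttl = floor(capacity/rate*2) = 4 seconds, so an idle bucket expires on its own.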

// A TokenLimiter controls how frequently events are allowed to happen within one second.
type TokenLimiter struct {
	rate           int
	burst          int
	store          *redis.Redis
	tokenKey       string
	timestampKey   string
	rescueLock     sync.Mutex
	redisAlive     uint32
	monitorStarted bool
	rescueLimiter  *xrate.Limiter
}

// NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
// bursts of at most burst tokens.
func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
	tokenKey := fmt.Sprintf(tokenFormat, key)
	timestampKey := fmt.Sprintf(timestampFormat, key)
	return &TokenLimiter{
		rate:          rate,
		burst:         burst,
		store:         store,
		tokenKey:      tokenKey,
		timestampKey:  timestampKey,
		redisAlive:    1,
		rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst),
	}
}
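
// For example, with rate = 100 the rescue limiter above is built from
// xrate.Every(time.Second/100), i.e. one token every 10ms, so the in-process
// fallback mirrors the Redis bucket's 100 events per second with a burst of
// `burst` whenever Redis is unreachable (numbers illustrative only).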

// Allow is shorthand for AllowN(time.Now(), 1).
func (lim *TokenLimiter) Allow() bool {
	return lim.AllowN(time.Now(), 1)
}

// AllowCtx is shorthand for AllowNCtx(ctx, time.Now(), 1) with incoming context.
func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
	return lim.AllowNCtx(ctx, time.Now(), 1)
}

// AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate.
// Otherwise, use Reserve or Wait.
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
	return lim.reserveN(context.Background(), now, n)
}

// AllowNCtx reports whether n events may happen at time now with incoming context.
// Use this method if you intend to drop / skip events that exceed the rate.
// Otherwise, use Reserve or Wait.
func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
	return lim.reserveN(ctx, now, n)
}

func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
	if atomic.LoadUint32(&lim.redisAlive) == 0 {
		return lim.rescueLimiter.AllowN(now, n)
	}

	resp, err := lim.store.ScriptRunCtx(ctx,
		script,
		[]string{
			lim.tokenKey,
			lim.timestampKey,
		},
		[]string{
			strconv.Itoa(lim.rate),
			strconv.Itoa(lim.burst),
			strconv.FormatInt(now.Unix(), 10),
			strconv.Itoa(n),
		})
	// redis allowed == false
	// Lua boolean false -> redis Nil bulk reply
	if err == redis.Nil {
		return false
	}
	// fail closed on context cancellation/timeout instead of switching to the
	// in-process limiter
	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
		logx.Errorf("fail to use rate limiter: %s", err)
		return false
	}
	if err != nil {
		logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
		lim.startMonitor()
		return lim.rescueLimiter.AllowN(now, n)
	}

	code, ok := resp.(int64)
	if !ok {
		logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
		lim.startMonitor()
		return lim.rescueLimiter.AllowN(now, n)
	}

	// redis allowed == true
	// Lua boolean true -> redis integer reply with value of 1
	return code == 1
}

func (lim *TokenLimiter) startMonitor() {
	lim.rescueLock.Lock()
	defer lim.rescueLock.Unlock()

	if lim.monitorStarted {
		return
	}

	lim.monitorStarted = true
	atomic.StoreUint32(&lim.redisAlive, 0)

	go lim.waitForRedis()
}

func (lim *TokenLimiter) waitForRedis() {
	ticker := time.NewTicker(pingInterval)
	defer func() {
		ticker.Stop()
		lim.rescueLock.Lock()
		lim.monitorStarted = false
		lim.rescueLock.Unlock()
	}()

	for range ticker.C {
		if lim.store.Ping() {
			atomic.StoreUint32(&lim.redisAlive, 1)
			return
		}
	}
}
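
// exampleUsage is a minimal usage sketch, assuming the store is built with
// redis.New from the imported redis package; the address "localhost:6379",
// the key "api-rate-limit" and the rate/burst values are placeholders, not
// taken from this file.
func exampleUsage() {
	store := redis.New("localhost:6379")

	// Allow up to 100 events per second, with bursts of at most 200.
	limiter := NewTokenLimiter(100, 200, store, "api-rate-limit")

	if limiter.Allow() {
		// proceed with the request
	} else {
		// rate limited: drop or delay the request
	}
}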