periodlimit.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package limit
  2. import (
  3. "context"
  4. "errors"
  5. "strconv"
  6. "time"
  7. "github.com/zeromicro/go-zero/core/stores/redis"
  8. )
// periodScript increments the counter stored at KEYS[1] by one and reports how
// the new count compares with the quota.
//
// Arguments:
//   - KEYS[1]: the counter key.
//   - ARGV[1]: the quota (limit).
//   - ARGV[2]: the window, passed to EXPIRE when the counter is created.
//
// Return values:
//   - 1: the new count is below the limit.
//   - 2: the new count equals the limit.
//   - 0: the limit has been exceeded.
//
// The expiry is set only when INCRBY yields 1 (the counter was just created),
// so the window starts at the first request and later requests do not extend it.
// NOTE(review): if the key already exists without a TTL, this script never sets
// one, and the counter would then never reset. Confirm nothing else writes these keys.
//
// To be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to
// reuse the key.
const periodScript = `local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then
redis.call("expire", KEYS[1], window)
end
if current < limit then
return 1
elseif current == limit then
return 2
else
return 0
end`
  23. const (
  24. // Unknown means not initialized state.
  25. Unknown = iota
  26. // Allowed means allowed state.
  27. Allowed
  28. // HitQuota means this request exactly hit the quota.
  29. HitQuota
  30. // OverQuota means passed the quota.
  31. OverQuota
  32. internalOverQuota = 0
  33. internalAllowed = 1
  34. internalHitQuota = 2
  35. )
  36. // ErrUnknownCode is an error that represents unknown status code.
  37. var ErrUnknownCode = errors.New("unknown status code")
  38. type (
  39. // PeriodOption defines the method to customize a PeriodLimit.
  40. PeriodOption func(l *PeriodLimit)
  41. // A PeriodLimit is used to limit requests during a period of time.
  42. PeriodLimit struct {
  43. period int
  44. quota int
  45. limitStore *redis.Redis
  46. keyPrefix string
  47. align bool
  48. }
  49. )
  50. // NewPeriodLimit returns a PeriodLimit with given parameters.
  51. func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
  52. opts ...PeriodOption) *PeriodLimit {
  53. limiter := &PeriodLimit{
  54. period: period,
  55. quota: quota,
  56. limitStore: limitStore,
  57. keyPrefix: keyPrefix,
  58. }
  59. for _, opt := range opts {
  60. opt(limiter)
  61. }
  62. return limiter
  63. }
  64. // Take requests a permit, it returns the permit state.
  65. func (h *PeriodLimit) Take(key string) (int, error) {
  66. return h.TakeCtx(context.Background(), key)
  67. }
  68. // TakeCtx requests a permit with context, it returns the permit state.
  69. func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
  70. resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
  71. strconv.Itoa(h.quota),
  72. strconv.Itoa(h.calcExpireSeconds()),
  73. })
  74. if err != nil {
  75. return Unknown, err
  76. }
  77. code, ok := resp.(int64)
  78. if !ok {
  79. return Unknown, ErrUnknownCode
  80. }
  81. switch code {
  82. case internalOverQuota:
  83. return OverQuota, nil
  84. case internalAllowed:
  85. return Allowed, nil
  86. case internalHitQuota:
  87. return HitQuota, nil
  88. default:
  89. return Unknown, ErrUnknownCode
  90. }
  91. }
  92. func (h *PeriodLimit) calcExpireSeconds() int {
  93. if h.align {
  94. now := time.Now()
  95. _, offset := now.Zone()
  96. unix := now.Unix() + int64(offset)
  97. return h.period - int(unix%int64(h.period))
  98. }
  99. return h.period
  100. }
  101. // Align returns a func to customize a PeriodLimit with alignment.
  102. // For example, if we want to limit end users with 5 sms verification messages every day,
  103. // we need to align with the local timezone and the start of the day.
  104. func Align() PeriodOption {
  105. return func(l *PeriodLimit) {
  106. l.align = true
  107. }
  108. }