periodlimit.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package limit
  2. import (
  3. "errors"
  4. "strconv"
  5. "time"
  6. "github.com/zeromicro/go-zero/core/stores/redis"
  7. )
  8. // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
  9. const periodScript = `local limit = tonumber(ARGV[1])
  10. local window = tonumber(ARGV[2])
  11. local current = redis.call("INCRBY", KEYS[1], 1)
  12. if current == 1 then
  13. redis.call("expire", KEYS[1], window)
  14. return 1
  15. elseif current < limit then
  16. return 1
  17. elseif current == limit then
  18. return 2
  19. else
  20. return 0
  21. end`
  22. const (
  23. // Unknown means not initialized state.
  24. Unknown = iota
  25. // Allowed means allowed state.
  26. Allowed
  27. // HitQuota means this request exactly hit the quota.
  28. HitQuota
  29. // OverQuota means passed the quota.
  30. OverQuota
  31. internalOverQuota = 0
  32. internalAllowed = 1
  33. internalHitQuota = 2
  34. )
  35. // ErrUnknownCode is an error that represents unknown status code.
  36. var ErrUnknownCode = errors.New("unknown status code")
  37. type (
  38. // PeriodOption defines the method to customize a PeriodLimit.
  39. PeriodOption func(l *PeriodLimit)
  40. // A PeriodLimit is used to limit requests during a period of time.
  41. PeriodLimit struct {
  42. period int
  43. quota int
  44. limitStore *redis.Redis
  45. keyPrefix string
  46. align bool
  47. }
  48. )
  49. // NewPeriodLimit returns a PeriodLimit with given parameters.
  50. func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
  51. opts ...PeriodOption) *PeriodLimit {
  52. limiter := &PeriodLimit{
  53. period: period,
  54. quota: quota,
  55. limitStore: limitStore,
  56. keyPrefix: keyPrefix,
  57. }
  58. for _, opt := range opts {
  59. opt(limiter)
  60. }
  61. return limiter
  62. }
  63. // Take requests a permit, it returns the permit state.
  64. func (h *PeriodLimit) Take(key string) (int, error) {
  65. resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{
  66. strconv.Itoa(h.quota),
  67. strconv.Itoa(h.calcExpireSeconds()),
  68. })
  69. if err != nil {
  70. return Unknown, err
  71. }
  72. code, ok := resp.(int64)
  73. if !ok {
  74. return Unknown, ErrUnknownCode
  75. }
  76. switch code {
  77. case internalOverQuota:
  78. return OverQuota, nil
  79. case internalAllowed:
  80. return Allowed, nil
  81. case internalHitQuota:
  82. return HitQuota, nil
  83. default:
  84. return Unknown, ErrUnknownCode
  85. }
  86. }
  87. func (h *PeriodLimit) calcExpireSeconds() int {
  88. if h.align {
  89. now := time.Now()
  90. _, offset := now.Zone()
  91. unix := now.Unix() + int64(offset)
  92. return h.period - int(unix%int64(h.period))
  93. }
  94. return h.period
  95. }
  96. // Align returns a func to customize a PeriodLimit with alignment.
  97. // For example, if we want to limit end users with 5 sms verification messages every day,
  98. // we need to align with the local timezone and the start of the day.
  99. func Align() PeriodOption {
  100. return func(l *PeriodLimit) {
  101. l.align = true
  102. }
  103. }