periodlimit.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. package limit
  2. import (
  3. "context"
  4. "errors"
  5. "strconv"
  6. "time"
  7. "github.com/wuntsong-org/go-zero-plus/core/stores/redis"
  8. )
  9. const (
  10. // Unknown means not initialized state.
  11. Unknown = iota
  12. // Allowed means allowed state.
  13. Allowed
  14. // HitQuota means this request exactly hit the quota.
  15. HitQuota
  16. // OverQuota means passed the quota.
  17. OverQuota
  18. internalOverQuota = 0
  19. internalAllowed = 1
  20. internalHitQuota = 2
  21. )
  22. var (
  23. // ErrUnknownCode is an error that represents unknown status code.
  24. ErrUnknownCode = errors.New("unknown status code")
  25. // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
  26. periodScript = redis.NewScript(`local limit = tonumber(ARGV[1])
  27. local window = tonumber(ARGV[2])
  28. local current = redis.call("INCRBY", KEYS[1], 1)
  29. if current == 1 then
  30. redis.call("expire", KEYS[1], window)
  31. end
  32. if current < limit then
  33. return 1
  34. elseif current == limit then
  35. return 2
  36. else
  37. return 0
  38. end`)
  39. )
  40. type (
  41. // PeriodOption defines the method to customize a PeriodLimit.
  42. PeriodOption func(l *PeriodLimit)
  43. // A PeriodLimit is used to limit requests during a period of time.
  44. PeriodLimit struct {
  45. period int
  46. quota int
  47. limitStore *redis.Redis
  48. keyPrefix string
  49. align bool
  50. }
  51. )
  52. // NewPeriodLimit returns a PeriodLimit with given parameters.
  53. func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
  54. opts ...PeriodOption) *PeriodLimit {
  55. limiter := &PeriodLimit{
  56. period: period,
  57. quota: quota,
  58. limitStore: limitStore,
  59. keyPrefix: keyPrefix,
  60. }
  61. for _, opt := range opts {
  62. opt(limiter)
  63. }
  64. return limiter
  65. }
  66. // Take requests a permit, it returns the permit state.
  67. func (h *PeriodLimit) Take(key string) (int, error) {
  68. return h.TakeCtx(context.Background(), key)
  69. }
  70. // TakeCtx requests a permit with context, it returns the permit state.
  71. func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
  72. resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
  73. strconv.Itoa(h.quota),
  74. strconv.Itoa(h.calcExpireSeconds()),
  75. })
  76. if err != nil {
  77. return Unknown, err
  78. }
  79. code, ok := resp.(int64)
  80. if !ok {
  81. return Unknown, ErrUnknownCode
  82. }
  83. switch code {
  84. case internalOverQuota:
  85. return OverQuota, nil
  86. case internalAllowed:
  87. return Allowed, nil
  88. case internalHitQuota:
  89. return HitQuota, nil
  90. default:
  91. return Unknown, ErrUnknownCode
  92. }
  93. }
  94. func (h *PeriodLimit) calcExpireSeconds() int {
  95. if h.align {
  96. now := time.Now()
  97. _, offset := now.Zone()
  98. unix := now.Unix() + int64(offset)
  99. return h.period - int(unix%int64(h.period))
  100. }
  101. return h.period
  102. }
  103. // Align returns a func to customize a PeriodLimit with alignment.
  104. // For example, if we want to limit end users with 5 sms verification messages every day,
  105. // we need to align with the local timezone and the start of the day.
  106. func Align() PeriodOption {
  107. return func(l *PeriodLimit) {
  108. l.align = true
  109. }
  110. }