123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- package limit
- import (
- "context"
- "errors"
- "strconv"
- "time"
- "github.com/wuntsong-org/go-zero-plus/core/stores/redis"
- )
// Permit states reported by PeriodLimit.Take and PeriodLimit.TakeCtx.
const (
	// Unknown means not initialized state; it is also returned alongside a
	// non-nil error when the permit state could not be determined.
	Unknown = iota
	// Allowed means allowed state.
	Allowed
	// HitQuota means this request exactly hit the quota.
	HitQuota
	// OverQuota means passed the quota.
	OverQuota
)

// Raw status codes produced by the Lua period script. They must match the
// literal `return` values in that script: 1 while the counter is below the
// limit, 2 when it equals the limit, 0 once it is above the limit.
const (
	internalOverQuota = 0
	internalAllowed   = 1
	internalHitQuota  = 2
)
- var (
- // ErrUnknownCode is an error that represents unknown status code.
- ErrUnknownCode = errors.New("unknown status code")
- // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
- periodScript = redis.NewScript(`local limit = tonumber(ARGV[1])
- local window = tonumber(ARGV[2])
- local current = redis.call("INCRBY", KEYS[1], 1)
- if current == 1 then
- redis.call("expire", KEYS[1], window)
- end
- if current < limit then
- return 1
- elseif current == limit then
- return 2
- else
- return 0
- end`)
- )
- type (
- // PeriodOption defines the method to customize a PeriodLimit.
- PeriodOption func(l *PeriodLimit)
- // A PeriodLimit is used to limit requests during a period of time.
- PeriodLimit struct {
- period int
- quota int
- limitStore *redis.Redis
- keyPrefix string
- align bool
- }
- )
- // NewPeriodLimit returns a PeriodLimit with given parameters.
- func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
- opts ...PeriodOption) *PeriodLimit {
- limiter := &PeriodLimit{
- period: period,
- quota: quota,
- limitStore: limitStore,
- keyPrefix: keyPrefix,
- }
- for _, opt := range opts {
- opt(limiter)
- }
- return limiter
- }
- // Take requests a permit, it returns the permit state.
- func (h *PeriodLimit) Take(key string) (int, error) {
- return h.TakeCtx(context.Background(), key)
- }
- // TakeCtx requests a permit with context, it returns the permit state.
- func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
- resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
- strconv.Itoa(h.quota),
- strconv.Itoa(h.calcExpireSeconds()),
- })
- if err != nil {
- return Unknown, err
- }
- code, ok := resp.(int64)
- if !ok {
- return Unknown, ErrUnknownCode
- }
- switch code {
- case internalOverQuota:
- return OverQuota, nil
- case internalAllowed:
- return Allowed, nil
- case internalHitQuota:
- return HitQuota, nil
- default:
- return Unknown, ErrUnknownCode
- }
- }
- func (h *PeriodLimit) calcExpireSeconds() int {
- if h.align {
- now := time.Now()
- _, offset := now.Zone()
- unix := now.Unix() + int64(offset)
- return h.period - int(unix%int64(h.period))
- }
- return h.period
- }
- // Align returns a func to customize a PeriodLimit with alignment.
- // For example, if we want to limit end users with 5 sms verification messages every day,
- // we need to align with the local timezone and the start of the day.
- func Align() PeriodOption {
- return func(l *PeriodLimit) {
- l.align = true
- }
- }
|