Просмотр исходного кода

feat(redis): add ScriptRun API and migrate EvalCtx to ScriptRun for limit, lock and bloom (#3087)

cong 2 лет назад
Родитель
Сommit
5da8a93c75

+ 12 - 10
core/bloom/bloom.go

@@ -11,25 +11,27 @@ import (
 const (
 	// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
 	// maps as k in the error rate table
-	maps      = 14
-	setScript = `
+	maps = 14
+)
+
+var (
+	// ErrTooLargeOffset indicates the offset is too large in bitset.
+	ErrTooLargeOffset = errors.New("too large offset")
+	setScript         = redis.NewScript(`
 for _, offset in ipairs(ARGV) do
 	redis.call("setbit", KEYS[1], offset, 1)
 end
-`
-	testScript = `
+`)
+	testScript = redis.NewScript(`
 for _, offset in ipairs(ARGV) do
 	if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
 		return false
 	end
 end
 return true
-`
+`)
 )
 
-// ErrTooLargeOffset indicates the offset is too large in bitset.
-var ErrTooLargeOffset = errors.New("too large offset")
-
 type (
 	// A Filter is a bloom filter.
 	Filter struct {
@@ -117,7 +119,7 @@ func (r *redisBitSet) check(offsets []uint) (bool, error) {
 		return false, err
 	}
 
-	resp, err := r.store.Eval(testScript, []string{r.key}, args)
+	resp, err := r.store.ScriptRun(testScript, []string{r.key}, args)
 	if err == redis.Nil {
 		return false, nil
 	} else if err != nil {
@@ -147,7 +149,7 @@ func (r *redisBitSet) set(offsets []uint) error {
 		return err
 	}
 
-	_, err = r.store.Eval(setScript, []string{r.key}, args)
+	_, err = r.store.ScriptRun(setScript, []string{r.key}, args)
 	if err == redis.Nil {
 		return nil
 	}

+ 20 - 18
core/limit/periodlimit.go

@@ -9,21 +9,6 @@ import (
 	"github.com/zeromicro/go-zero/core/stores/redis"
 )
 
-// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
-const periodScript = `local limit = tonumber(ARGV[1])
-local window = tonumber(ARGV[2])
-local current = redis.call("INCRBY", KEYS[1], 1)
-if current == 1 then
-    redis.call("expire", KEYS[1], window)
-end
-if current < limit then
-    return 1
-elseif current == limit then
-    return 2
-else
-    return 0
-end`
-
 const (
 	// Unknown means not initialized state.
 	Unknown = iota
@@ -39,8 +24,25 @@ const (
 	internalHitQuota  = 2
 )
 
-// ErrUnknownCode is an error that represents unknown status code.
-var ErrUnknownCode = errors.New("unknown status code")
+var (
+	// ErrUnknownCode is an error that represents unknown status code.
+	ErrUnknownCode = errors.New("unknown status code")
+
+	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
+	periodScript = redis.NewScript(`local limit = tonumber(ARGV[1])
+local window = tonumber(ARGV[2])
+local current = redis.call("INCRBY", KEYS[1], 1)
+if current == 1 then
+    redis.call("expire", KEYS[1], window)
+end
+if current < limit then
+    return 1
+elseif current == limit then
+    return 2
+else
+    return 0
+end`)
+)
 
 type (
 	// PeriodOption defines the method to customize a PeriodLimit.
@@ -80,7 +82,7 @@ func (h *PeriodLimit) Take(key string) (int, error) {
 
 // TakeCtx requests a permit with context, it returns the permit state.
 func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
-	resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
+	resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
 		strconv.Itoa(h.quota),
 		strconv.Itoa(h.calcExpireSeconds()),
 	})

+ 11 - 10
core/limit/tokenlimit.go

@@ -15,10 +15,15 @@ import (
 )
 
 const (
-	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
-	// KEYS[1] as tokens_key
-	// KEYS[2] as timestamp_key
-	script = `local rate = tonumber(ARGV[1])
+	tokenFormat     = "{%s}.tokens"
+	timestampFormat = "{%s}.ts"
+	pingInterval    = time.Millisecond * 100
+)
+
+// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
+// KEYS[1] as tokens_key
+// KEYS[2] as timestamp_key
+var script = redis.NewScript(`local rate = tonumber(ARGV[1])
 local capacity = tonumber(ARGV[2])
 local now = tonumber(ARGV[3])
 local requested = tonumber(ARGV[4])
@@ -45,11 +50,7 @@ end
 redis.call("setex", KEYS[1], ttl, new_tokens)
 redis.call("setex", KEYS[2], ttl, now)
 
-return allowed`
-	tokenFormat     = "{%s}.tokens"
-	timestampFormat = "{%s}.ts"
-	pingInterval    = time.Millisecond * 100
-)
+return allowed`)
 
 // A TokenLimiter controls how frequently events are allowed to happen with in one second.
 type TokenLimiter struct {
@@ -110,7 +111,7 @@ func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) boo
 		return lim.rescueLimiter.AllowN(now, n)
 	}
 
-	resp, err := lim.store.EvalCtx(ctx,
+	resp, err := lim.store.ScriptRunCtx(ctx,
 		script,
 		[]string{
 			lim.tokenKey,

+ 26 - 0
core/stores/redis/redis.go

@@ -87,6 +87,8 @@ type (
 	FloatCmd = red.FloatCmd
 	// StringCmd is an alias of redis.StringCmd.
 	StringCmd = red.StringCmd
+	// Script is an alias of redis.Script.
+	Script = red.Script
 )
 
 // New returns a Redis with given options.
@@ -145,6 +147,11 @@ func newRedis(addr string, opts ...Option) *Redis {
 	return r
 }
 
+// NewScript returns a new Script instance.
+func NewScript(script string) *Script {
+	return red.NewScript(script)
+}
+
 // BitCount is redis bitcount command implementation.
 func (s *Redis) BitCount(key string, start, end int64) (int64, error) {
 	return s.BitCountCtx(context.Background(), key, start, end)
@@ -1630,6 +1637,25 @@ func (s *Redis) ScriptLoadCtx(ctx context.Context, script string) (string, error
 	return conn.ScriptLoad(ctx, script).Result()
 }
 
+// ScriptRun is the implementation of *redis.Script run command.
+func (s *Redis) ScriptRun(script *Script, keys []string, args ...any) (any, error) {
+	return s.ScriptRunCtx(context.Background(), script, keys, args...)
+}
+
+// ScriptRunCtx is the implementation of *redis.Script run command.
+func (s *Redis) ScriptRunCtx(ctx context.Context, script *Script, keys []string, args ...any) (val any, err error) {
+	err = s.brk.DoWithAcceptable(func() error {
+		conn, err := getRedis(s)
+		if err != nil {
+			return err
+		}
+
+		val, err = script.Run(ctx, conn, keys, args...).Result()
+		return err
+	}, acceptable)
+	return
+}
+
 // Set is the implementation of redis set command.
 func (s *Redis) Set(key, value string) error {
 	return s.SetCtx(context.Background(), key, value)

+ 18 - 0
core/stores/redis/redis_test.go

@@ -240,6 +240,24 @@ func TestRedis_Eval(t *testing.T) {
 	})
 }
 
+func TestRedis_ScriptRun(t *testing.T) {
+	runOnRedis(t, func(client *Redis) {
+		sc := NewScript(`redis.call("EXISTS", KEYS[1])`)
+		sc2 := NewScript(`return redis.call("EXISTS", KEYS[1])`)
+		_, err := New(client.Addr, badType()).ScriptRun(sc, []string{"notexist"})
+		assert.NotNil(t, err)
+		_, err = client.ScriptRun(sc, []string{"notexist"})
+		assert.Equal(t, Nil, err)
+		err = client.Set("key1", "value1")
+		assert.Nil(t, err)
+		_, err = client.ScriptRun(sc, []string{"key1"})
+		assert.Equal(t, Nil, err)
+		val, err := client.ScriptRun(sc2, []string{"key1"})
+		assert.Nil(t, err)
+		assert.Equal(t, int64(1), val)
+	})
+}
+
 func TestRedis_GeoHash(t *testing.T) {
 	runOnRedis(t, func(client *Redis) {
 		_, err := client.GeoHash("parent", "child1", "child2")

+ 9 - 6
core/stores/redis/redislock.go

@@ -17,17 +17,20 @@ const (
 	randomLen       = 16
 	tolerance       = 500 // milliseconds
 	millisPerSecond = 1000
-	lockCommand     = `if redis.call("GET", KEYS[1]) == ARGV[1] then
+)
+
+var (
+	lockScript = NewScript(`if redis.call("GET", KEYS[1]) == ARGV[1] then
     redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
     return "OK"
 else
     return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
-end`
-	delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
+end`)
+	delScript = NewScript(`if redis.call("GET", KEYS[1]) == ARGV[1] then
     return redis.call("DEL", KEYS[1])
 else
     return 0
-end`
+end`)
 )
 
 // A RedisLock is a redis lock.
@@ -59,7 +62,7 @@ func (rl *RedisLock) Acquire() (bool, error) {
 // AcquireCtx acquires the lock with the given ctx.
 func (rl *RedisLock) AcquireCtx(ctx context.Context) (bool, error) {
 	seconds := atomic.LoadUint32(&rl.seconds)
-	resp, err := rl.store.EvalCtx(ctx, lockCommand, []string{rl.key}, []string{
+	resp, err := rl.store.ScriptRunCtx(ctx, lockScript, []string{rl.key}, []string{
 		rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance),
 	})
 	if err == red.Nil {
@@ -87,7 +90,7 @@ func (rl *RedisLock) Release() (bool, error) {
 
 // ReleaseCtx releases the lock with the given ctx.
 func (rl *RedisLock) ReleaseCtx(ctx context.Context) (bool, error) {
-	resp, err := rl.store.EvalCtx(ctx, delCommand, []string{rl.key}, []string{rl.id})
+	resp, err := rl.store.ScriptRunCtx(ctx, delScript, []string{rl.key}, []string{rl.id})
 	if err != nil {
 		return false, err
 	}