Ver código fonte

feat: return original value of setbit in redis (#1746)

Kevin Wan 3 anos atrás
pai
commit
e0fa8d820d
2 arquivos alterados com 21 adições e 11 exclusões
  1. 12 5
      core/stores/redis/redis.go
  2. 9 6
      core/stores/redis/redis_test.go

+ 12 - 5
core/stores/redis/redis.go

@@ -1404,21 +1404,28 @@ func (s *Redis) ScanCtx(ctx context.Context, cursor uint64, match string, count
 }
 
 // SetBit is the implementation of redis setbit command.
-func (s *Redis) SetBit(key string, offset int64, value int) error {
+func (s *Redis) SetBit(key string, offset int64, value int) (int, error) {
 	return s.SetBitCtx(context.Background(), key, offset, value)
 }
 
 // SetBitCtx is the implementation of redis setbit command.
-func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) error {
-	return s.brk.DoWithAcceptable(func() error {
+func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) (val int, err error) {
+	err = s.brk.DoWithAcceptable(func() error {
 		conn, err := getRedis(s)
 		if err != nil {
 			return err
 		}
 
-		_, err = conn.SetBit(ctx, key, offset, value).Result()
-		return err
+		v, err := conn.SetBit(ctx, key, offset, value).Result()
+		if err != nil {
+			return err
+		}
+
+		val = int(v)
+		return nil
 	}, acceptable)
+
+	return
 }
 
 // Sscan is the implementation of redis sscan command.

+ 9 - 6
core/stores/redis/redis_test.go

@@ -387,30 +387,33 @@ func TestRedis_Mget(t *testing.T) {
 
 func TestRedis_SetBit(t *testing.T) {
 	runOnRedis(t, func(client *Redis) {
-		err := New(client.Addr, badType()).SetBit("key", 1, 1)
+		_, err := New(client.Addr, badType()).SetBit("key", 1, 1)
 		assert.NotNil(t, err)
-		err = client.SetBit("key", 1, 1)
+		val, err := client.SetBit("key", 1, 1)
 		assert.Nil(t, err)
+		assert.Equal(t, 0, val)
 	})
 }
 
 func TestRedis_GetBit(t *testing.T) {
 	runOnRedis(t, func(client *Redis) {
-		err := client.SetBit("key", 2, 1)
+		val, err := client.SetBit("key", 2, 1)
 		assert.Nil(t, err)
+		assert.Equal(t, 0, val)
 		_, err = New(client.Addr, badType()).GetBit("key", 2)
 		assert.NotNil(t, err)
-		val, err := client.GetBit("key", 2)
+		v, err := client.GetBit("key", 2)
 		assert.Nil(t, err)
-		assert.Equal(t, 1, val)
+		assert.Equal(t, 1, v)
 	})
 }
 
 func TestRedis_BitCount(t *testing.T) {
 	runOnRedis(t, func(client *Redis) {
 		for i := 0; i < 11; i++ {
-			err := client.SetBit("key", int64(i), 1)
+			val, err := client.SetBit("key", int64(i), 1)
 			assert.Nil(t, err)
+			assert.Equal(t, 0, val)
 		}
 
 		_, err := New(client.Addr, badType()).BitCount("key", 0, -1)