Browse Source

feat: add getset command in redis and kv (#1693)

benqi 3 years ago
parent
commit
bbac994c8a

+ 10 - 0
core/stores/kv/store.go

@@ -54,6 +54,7 @@ type (
 		Setex(key, value string, seconds int) error
 		Setex(key, value string, seconds int) error
 		Setnx(key, value string) (bool, error)
 		Setnx(key, value string) (bool, error)
 		SetnxEx(key, value string, seconds int) (bool, error)
 		SetnxEx(key, value string, seconds int) (bool, error)
+		Getset(key, value string) (string, error)
 		Sismember(key string, value interface{}) (bool, error)
 		Sismember(key string, value interface{}) (bool, error)
 		Smembers(key string) ([]string, error)
 		Smembers(key string) ([]string, error)
 		Spop(key string) (string, error)
 		Spop(key string) (string, error)
@@ -459,6 +460,15 @@ func (cs clusterStore) SetnxEx(key, value string, seconds int) (bool, error) {
 	return node.SetnxEx(key, value, seconds)
 	return node.SetnxEx(key, value, seconds)
 }
 }
 
 
+func (cs clusterStore) Getset(key, value string) (string, error) {
+	node, err := cs.getRedis(key)
+	if err != nil {
+		return "", err
+	}
+
+	return node.GetSet(key, value)
+}
+
 func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) {
 func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) {
 	node, err := cs.getRedis(key)
 	node, err := cs.getRedis(key)
 	if err != nil {
 	if err != nil {

+ 23 - 0
core/stores/kv/store_test.go

@@ -490,6 +490,29 @@ func TestRedis_SetExNx(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestRedis_Getset(t *testing.T) {
+	store := clusterStore{dispatcher: hash.NewConsistentHash()}
+	_, err := store.Getset("hello", "world")
+	assert.NotNil(t, err)
+
+	runOnCluster(t, func(client Store) {
+		val, err := client.Getset("hello", "world")
+		assert.Nil(t, err)
+		assert.Equal(t, "", val)
+		val, err = client.Get("hello")
+		assert.Nil(t, err)
+		assert.Equal(t, "world", val)
+		val, err = client.Getset("hello", "newworld")
+		assert.Nil(t, err)
+		assert.Equal(t, "world", val)
+		val, err = client.Get("hello")
+		assert.Nil(t, err)
+		assert.Equal(t, "newworld", val)
+		_, err = client.Del("hello")
+		assert.Nil(t, err)
+	})
+}
+
 func TestRedis_SetGetDelHashField(t *testing.T) {
 func TestRedis_SetGetDelHashField(t *testing.T) {
 	store := clusterStore{dispatcher: hash.NewConsistentHash()}
 	store := clusterStore{dispatcher: hash.NewConsistentHash()}
 	err := store.Hset("key", "field", "value")
 	err := store.Hset("key", "field", "value")

+ 25 - 0
core/stores/redis/redis.go

@@ -615,6 +615,31 @@ func (s *Redis) GetCtx(ctx context.Context, key string) (val string, err error)
 	return
 	return
 }
 }
 
 
+// GetSet is the implementation of redis getset command.
+func (s *Redis) GetSet(key, value string) (string, error) {
+	return s.GetSetCtx(context.Background(), key, value)
+}
+
+// GetSetCtx is the implementation of redis getset command.
+func (s *Redis) GetSetCtx(ctx context.Context, key, value string) (val string, err error) {
+	err = s.brk.DoWithAcceptable(func() error {
+		conn, err := getRedis(s)
+		if err != nil {
+			return err
+		}
+
+		if val, err = conn.GetSet(ctx, key, value).Result(); err == red.Nil {
+			return nil
+		} else if err != nil {
+			return err
+		} else {
+			return nil
+		}
+	}, acceptable)
+
+	return
+}
+
 // GetBit is the implementation of redis getbit command.
 // GetBit is the implementation of redis getbit command.
 func (s *Redis) GetBit(key string, offset int64) (int, error) {
 func (s *Redis) GetBit(key string, offset int64) (int, error) {
 	return s.GetBitCtx(context.Background(), key, offset)
 	return s.GetBitCtx(context.Background(), key, offset)

+ 22 - 0
core/stores/redis/redis_test.go

@@ -701,6 +701,28 @@ func TestRedis_Set(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestRedis_GetSet(t *testing.T) {
+	runOnRedis(t, func(client *Redis) {
+		val, err := New(client.Addr, badType()).GetSet("hello", "world")
+		assert.NotNil(t, err)
+		val, err = client.GetSet("hello", "world")
+		assert.Nil(t, err)
+		assert.Equal(t, "", val)
+		val, err = client.Get("hello")
+		assert.Nil(t, err)
+		assert.Equal(t, "world", val)
+		val, err = client.GetSet("hello", "newworld")
+		assert.Nil(t, err)
+		assert.Equal(t, "world", val)
+		val, err = client.Get("hello")
+		assert.Nil(t, err)
+		assert.Equal(t, "newworld", val)
+		ret, err := client.Del("hello")
+		assert.Nil(t, err)
+		assert.Equal(t, 1, ret)
+	})
+}
+
 func TestRedis_SetGetDel(t *testing.T) {
 func TestRedis_SetGetDel(t *testing.T) {
 	runOnRedis(t, func(client *Redis) {
 	runOnRedis(t, func(client *Redis) {
 		err := New(client.Addr, badType()).Set("hello", "world")
 		err := New(client.Addr, badType()).Set("hello", "world")