Переглянути джерело

refactor(redistest): simplify redistest.CreateRedis API (#3086)

cong 2 роки тому
батько
коміт
b49fc81618

+ 9 - 14
core/bloom/bloom_test.go

@@ -4,13 +4,12 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stores/redis/redistest"
 )
 
 func TestRedisBitSet_New_Set_Test(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	bitSet := newRedisBitSet(store, "test_key", 1024)
 	isSetBefore, err := bitSet.check([]uint{0})
@@ -42,9 +41,7 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
 }
 
 func TestRedisBitSet_Add(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	filter := New(store, "test_key", 64)
 	assert.Nil(t, filter.Add([]byte("hello")))
@@ -55,11 +52,10 @@ func TestRedisBitSet_Add(t *testing.T) {
 }
 
 func TestFilter_Exists(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
+	store, clean := redistest.CreateRedisWithClean(t)
 
 	rbs := New(store, "test", 64)
-	_, err = rbs.Exists([]byte{0, 1, 2})
+	_, err := rbs.Exists([]byte{0, 1, 2})
 	assert.NoError(t, err)
 
 	clean()
@@ -69,12 +65,11 @@ func TestFilter_Exists(t *testing.T) {
 }
 
 func TestRedisBitSet_check(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
+	store, clean := redistest.CreateRedisWithClean(t)
 
 	rbs := newRedisBitSet(store, "test", 0)
 	assert.Error(t, rbs.set([]uint{0, 1, 2}))
-	_, err = rbs.check([]uint{0, 1, 2})
+	_, err := rbs.check([]uint{0, 1, 2})
 	assert.Error(t, err)
 
 	rbs = newRedisBitSet(store, "test", 64)
@@ -88,8 +83,8 @@ func TestRedisBitSet_check(t *testing.T) {
 }
 
 func TestRedisBitSet_set(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
+	logx.Disable()
+	store, clean := redistest.CreateRedisWithClean(t)
 
 	rbs := newRedisBitSet(store, "test", 0)
 	assert.Error(t, rbs.set([]uint{0, 1, 2}))

+ 1 - 3
core/limit/periodlimit_test.go

@@ -33,9 +33,7 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) {
 }
 
 func testPeriodLimit(t *testing.T, opts ...PeriodOption) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	const (
 		seconds = 1

+ 2 - 6
core/limit/tokenlimit_test.go

@@ -70,9 +70,7 @@ func TestTokenLimit_Rescue(t *testing.T) {
 }
 
 func TestTokenLimit_Take(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	const (
 		total = 100
@@ -92,9 +90,7 @@ func TestTokenLimit_Take(t *testing.T) {
 }
 
 func TestTokenLimit_TakeBurst(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	const (
 		total = 100

+ 3 - 9
core/stores/cache/cache_test.go

@@ -112,12 +112,8 @@ func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val any, key string
 func TestCache_SetDel(t *testing.T) {
 	t.Run("test set del", func(t *testing.T) {
 		const total = 1000
-		r1, clean1, err := redistest.CreateRedis()
-		assert.Nil(t, err)
-		defer clean1()
-		r2, clean2, err := redistest.CreateRedis()
-		assert.Nil(t, err)
-		defer clean2()
+		r1 := redistest.CreateRedis(t)
+		r2 := redistest.CreateRedis(t)
 		conf := ClusterConf{
 			{
 				RedisConf: redis.RedisConf{
@@ -193,9 +189,7 @@ func TestCache_SetDel(t *testing.T) {
 
 func TestCache_OneNode(t *testing.T) {
 	const total = 1000
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 	conf := ClusterConf{
 		{
 			RedisConf: redis.RedisConf{

+ 13 - 31
core/stores/cache/cachenode_test.go

@@ -34,10 +34,8 @@ func init() {
 
 func TestCacheNode_DelCache(t *testing.T) {
 	t.Run("del cache", func(t *testing.T) {
-		store, clean, err := redistest.CreateRedis()
-		assert.Nil(t, err)
+		store := redistest.CreateRedis(t)
 		store.Type = redis.ClusterType
-		defer clean()
 
 		cn := cacheNode{
 			rds:            store,
@@ -84,9 +82,7 @@ func TestCacheNode_DelCache(t *testing.T) {
 }
 
 func TestCacheNode_DelCacheWithErrors(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 	store.Type = redis.ClusterType
 
 	cn := cacheNode{
@@ -122,9 +118,7 @@ func TestCacheNode_InvalidCache(t *testing.T) {
 }
 
 func TestCacheNode_SetWithExpire(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	cn := cacheNode{
 		rds:            store,
@@ -139,14 +133,12 @@ func TestCacheNode_SetWithExpire(t *testing.T) {
 }
 
 func TestCacheNode_Take(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	cn := NewNode(store, syncx.NewSingleFlight(), NewStat("any"), errTestNotFound,
 		WithExpiry(time.Second), WithNotFoundExpiry(time.Second))
 	var str string
-	err = cn.Take(&str, "any", func(v any) error {
+	err := cn.Take(&str, "any", func(v any) error {
 		*v.(*string) = "value"
 		return nil
 	})
@@ -174,9 +166,7 @@ func TestCacheNode_TakeBadRedis(t *testing.T) {
 }
 
 func TestCacheNode_TakeNotFound(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	cn := cacheNode{
 		rds:            store,
@@ -188,7 +178,7 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
 		errNotFound:    errTestNotFound,
 	}
 	var str string
-	err = cn.Take(&str, "any", func(v any) error {
+	err := cn.Take(&str, "any", func(v any) error {
 		return errTestNotFound
 	})
 	assert.True(t, cn.IsNotFound(err))
@@ -213,9 +203,7 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
 }
 
 func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.NoError(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	cn := cacheNode{
 		rds:            store,
@@ -228,7 +216,7 @@ func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
 	}
 
 	var str string
-	err = cn.Take(&str, "any", func(v any) error {
+	err := cn.Take(&str, "any", func(v any) error {
 		store.Set("any", "foo")
 		return errTestNotFound
 	})
@@ -242,9 +230,7 @@ func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
 }
 
 func TestCacheNode_TakeWithExpire(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	cn := cacheNode{
 		rds:            store,
@@ -256,7 +242,7 @@ func TestCacheNode_TakeWithExpire(t *testing.T) {
 		errNotFound:    errors.New("any"),
 	}
 	var str string
-	err = cn.TakeWithExpire(&str, "any", func(v any, expire time.Duration) error {
+	err := cn.TakeWithExpire(&str, "any", func(v any, expire time.Duration) error {
 		*v.(*string) = "value"
 		return nil
 	})
@@ -269,9 +255,7 @@ func TestCacheNode_TakeWithExpire(t *testing.T) {
 }
 
 func TestCacheNode_String(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	cn := cacheNode{
 		rds:            store,
@@ -286,9 +270,7 @@ func TestCacheNode_String(t *testing.T) {
 }
 
 func TestCacheValueWithBigInt(t *testing.T) {
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	cn := cacheNode{
 		rds:            store,

+ 9 - 20
core/stores/redis/redistest/redistest.go

@@ -1,31 +1,20 @@
 package redistest
 
 import (
-	"time"
+	"testing"
 
 	"github.com/alicebob/miniredis/v2"
-	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/core/stores/redis"
 )
 
 // CreateRedis returns an in process redis.Redis.
-func CreateRedis() (r *redis.Redis, clean func(), err error) {
-	mr, err := miniredis.Run()
-	if err != nil {
-		return nil, nil, err
-	}
-
-	return redis.New(mr.Addr()), func() {
-		ch := make(chan lang.PlaceholderType)
-
-		go func() {
-			mr.Close()
-			close(ch)
-		}()
+func CreateRedis(t *testing.T) *redis.Redis {
+	r, _ := CreateRedisWithClean(t)
+	return r
+}
 
-		select {
-		case <-ch:
-		case <-time.After(time.Second):
-		}
-	}, nil
+// CreateRedisWithClean returns an in process redis.Redis and a clean function.
+func CreateRedisWithClean(t *testing.T) (r *redis.Redis, clean func()) {
+	mr := miniredis.RunT(t)
+	return redis.New(mr.Addr()), mr.Close
 }

+ 31 - 63
core/stores/sqlc/cachedsql_test.go

@@ -33,13 +33,11 @@ func init() {
 
 func TestCachedConn_GetCache(t *testing.T) {
 	resetStats()
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
 	var value string
-	err = c.GetCache("any", &value)
+	err := c.GetCache("any", &value)
 	assert.Equal(t, ErrNotFound, err)
 	r.Set("any", `"value"`)
 	err = c.GetCache("any", &value)
@@ -49,15 +47,13 @@ func TestCachedConn_GetCache(t *testing.T) {
 
 func TestStat(t *testing.T) {
 	resetStats()
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
 
 	for i := 0; i < 10; i++ {
 		var str string
-		err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
+		err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
 			*v.(*string) = "zero"
 			return nil
 		})
@@ -72,9 +68,7 @@ func TestStat(t *testing.T) {
 
 func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
 	resetStats()
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	c := NewConn(dummySqlConn{}, cache.CacheConf{
 		{
@@ -87,7 +81,7 @@ func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
 	}, cache.WithExpiry(time.Second*10))
 
 	var str string
-	err = c.QueryRowIndex(&str, "index", func(s any) string {
+	err := c.QueryRowIndex(&str, "index", func(s any) string {
 		return fmt.Sprintf("%s/1234", s)
 	}, func(conn sqlx.SqlConn, v any) (any, error) {
 		*v.(*string) = "zero"
@@ -121,16 +115,14 @@ func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
 
 func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) {
 	resetStats()
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
 		cache.WithNotFoundExpiry(time.Second))
 
 	var str string
 	r.Set("index", `"primary"`)
-	err = c.QueryRowIndex(&str, "index", func(s any) string {
+	err := c.QueryRowIndex(&str, "index", func(s any) string {
 		return fmt.Sprintf("%s/1234", s)
 	}, func(conn sqlx.SqlConn, v any) (any, error) {
 		assert.Fail(t, "should not go here")
@@ -211,16 +203,14 @@ func TestCachedConn_QueryRowIndex_HasCache_IntPrimary(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			resetStats()
-			r, clean, err := redistest.CreateRedis()
-			assert.Nil(t, err)
-			defer clean()
+			r := redistest.CreateRedis(t)
 
 			c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
 				cache.WithNotFoundExpiry(time.Second))
 
 			var str string
 			r.Set("index", test.primaryCache)
-			err = c.QueryRowIndex(&str, "index", func(s any) string {
+			err := c.QueryRowIndex(&str, "index", func(s any) string {
 				return fmt.Sprintf("%v/1234", s)
 			}, func(conn sqlx.SqlConn, v any) (any, error) {
 				assert.Fail(t, "should not go here")
@@ -251,16 +241,14 @@ func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
 	for k, v := range caches {
 		t.Run(k+"/"+v, func(t *testing.T) {
 			resetStats()
-			r, clean, err := redistest.CreateRedis()
-			assert.Nil(t, err)
-			defer clean()
+			r := redistest.CreateRedis(t)
 
 			c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
 				cache.WithNotFoundExpiry(time.Second))
 
 			var str string
 			r.Set(k, v)
-			err = c.QueryRowIndex(&str, "index", func(s any) string {
+			err := c.QueryRowIndex(&str, "index", func(s any) string {
 				return fmt.Sprintf("%s/1234", s)
 			}, func(conn sqlx.SqlConn, v any) (any, error) {
 				*v.(*string) = "xin"
@@ -306,15 +294,13 @@ func TestStatCacheFails(t *testing.T) {
 
 func TestStatDbFails(t *testing.T) {
 	resetStats()
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
 
 	for i := 0; i < 20; i++ {
 		var str string
-		err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
+		err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
 			return errors.New("db failed")
 		})
 		assert.NotNil(t, err)
@@ -327,9 +313,7 @@ func TestStatDbFails(t *testing.T) {
 
 func TestStatFromMemory(t *testing.T) {
 	resetStats()
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
 
@@ -385,9 +369,7 @@ func TestStatFromMemory(t *testing.T) {
 }
 
 func TestCachedConnQueryRow(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	const (
 		key   = "user"
@@ -397,7 +379,7 @@ func TestCachedConnQueryRow(t *testing.T) {
 	var user string
 	var ran bool
 	c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
-	err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
+	err := c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
 		ran = true
 		user = value
 		return nil
@@ -413,9 +395,7 @@ func TestCachedConnQueryRow(t *testing.T) {
 }
 
 func TestCachedConnQueryRowFromCache(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	const (
 		key   = "user"
@@ -426,7 +406,7 @@ func TestCachedConnQueryRowFromCache(t *testing.T) {
 	var ran bool
 	c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
 	assert.Nil(t, c.SetCache(key, value))
-	err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
+	err := c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
 		ran = true
 		user = value
 		return nil
@@ -442,9 +422,7 @@ func TestCachedConnQueryRowFromCache(t *testing.T) {
 }
 
 func TestQueryRowNotFound(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	const key = "user"
 	var conn trackedConn
@@ -452,7 +430,7 @@ func TestQueryRowNotFound(t *testing.T) {
 	var ran int
 	c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
 	for i := 0; i < 20; i++ {
-		err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
+		err := c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
 			ran++
 			return sql.ErrNoRows
 		})
@@ -462,13 +440,11 @@ func TestQueryRowNotFound(t *testing.T) {
 }
 
 func TestCachedConnExec(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	var conn trackedConn
 	c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
-	_, err = c.ExecNoCache("delete from user_table where id='kevin'")
+	_, err := c.ExecNoCache("delete from user_table where id='kevin'")
 	assert.Nil(t, err)
 	assert.True(t, conn.execValue)
 }
@@ -514,26 +490,22 @@ func TestCachedConnExecDropCacheFailed(t *testing.T) {
 }
 
 func TestCachedConnQueryRows(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	var conn trackedConn
 	c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
 	var users []string
-	err = c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'")
+	err := c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'")
 	assert.Nil(t, err)
 	assert.True(t, conn.queryRowsValue)
 }
 
 func TestCachedConnTransact(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	var conn trackedConn
 	c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
-	err = c.Transact(func(session sqlx.Session) error {
+	err := c.Transact(func(session sqlx.Session) error {
 		return nil
 	})
 	assert.Nil(t, err)
@@ -541,9 +513,7 @@ func TestCachedConnTransact(t *testing.T) {
 }
 
 func TestQueryRowNoCache(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	const (
 		key   = "user"
@@ -557,20 +527,18 @@ func TestQueryRowNoCache(t *testing.T) {
 		return nil
 	}}
 	c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
-	err = c.QueryRowNoCache(&user, key)
+	err := c.QueryRowNoCache(&user, key)
 	assert.Nil(t, err)
 	assert.Equal(t, value, user)
 	assert.True(t, ran)
 }
 
 func TestNewConnWithCache(t *testing.T) {
-	r, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	r := redistest.CreateRedis(t)
 
 	var conn trackedConn
 	c := NewConnWithCache(&conn, cache.NewNode(r, singleFlights, stats, sql.ErrNoRows))
-	_, err = c.ExecNoCache("delete from user_table where id='kevin'")
+	_, err := c.ExecNoCache("delete from user_table where id='kevin'")
 	assert.Nil(t, err)
 	assert.True(t, conn.execValue)
 }

+ 1 - 3
zrpc/internal/auth/auth_test.go

@@ -43,9 +43,7 @@ func TestAuthenticator(t *testing.T) {
 		},
 	}
 
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {

+ 2 - 6
zrpc/internal/serverinterceptors/authinterceptor_test.go

@@ -45,9 +45,7 @@ func TestStreamAuthorizeInterceptor(t *testing.T) {
 		},
 	}
 
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
@@ -111,9 +109,7 @@ func TestUnaryAuthorizeInterceptor(t *testing.T) {
 		},
 	}
 
-	store, clean, err := redistest.CreateRedis()
-	assert.Nil(t, err)
-	defer clean()
+	store := redistest.CreateRedis(t)
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {