فهرست منبع

feat: add MustNewRedis (#2824)

* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x
MarkJoyMa 2 سال پیش
والد
کامیت
fde05ccb28

+ 3 - 2
core/stores/cache/cache.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/zeromicro/go-zero/core/errorx"
 	"github.com/zeromicro/go-zero/core/hash"
+	"github.com/zeromicro/go-zero/core/stores/redis"
 	"github.com/zeromicro/go-zero/core/syncx"
 )
 
@@ -62,12 +63,12 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
 	}
 
 	if len(c) == 1 {
-		return NewNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
+		return NewNode(redis.MustNewRedis(c[0].RedisConf), barrier, st, errNotFound, opts...)
 	}
 
 	dispatcher := hash.NewConsistentHash()
 	for _, node := range c {
-		cn := NewNode(node.NewRedis(), barrier, st, errNotFound, opts...)
+		cn := NewNode(redis.MustNewRedis(node.RedisConf), barrier, st, errNotFound, opts...)
 		dispatcher.AddWithWeight(cn, node.Weight)
 	}
 

+ 2 - 2
core/stores/cache/cache_test.go

@@ -163,12 +163,10 @@ func TestCache_SetDel(t *testing.T) {
 		r1, err := miniredis.Run()
 		assert.NoError(t, err)
 		defer r1.Close()
-		r1.SetError("mock error")
 
 		r2, err := miniredis.Run()
 		assert.NoError(t, err)
 		defer r2.Close()
-		r2.SetError("mock error")
 
 		conf := ClusterConf{
 			{
@@ -187,6 +185,8 @@ func TestCache_SetDel(t *testing.T) {
 			},
 		}
 		c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
+		r1.SetError("mock error")
+		r2.SetError("mock error")
 		assert.NoError(t, c.Del("a", "b", "c"))
 	})
 }

+ 1 - 1
core/stores/kv/store.go

@@ -164,7 +164,7 @@ func NewStore(c KvConf) Store {
 	// because Store and redis.Redis has different methods.
 	dispatcher := hash.NewConsistentHash()
 	for _, node := range c {
-		cn := node.NewRedis()
+		cn := redis.MustNewRedis(node.RedisConf)
 		dispatcher.AddWithWeight(cn, node.Weight)
 	}
 

+ 3 - 1
core/stores/redis/conf.go

@@ -9,6 +9,8 @@ var (
 	ErrEmptyType = errors.New("empty redis type")
 	// ErrEmptyKey is an error that indicates no redis key is set.
 	ErrEmptyKey = errors.New("empty redis key")
+	// ErrPing is an error that indicates ping failed.
+	ErrPing = errors.New("ping redis failed")
 )
 
 type (
@@ -27,7 +29,7 @@ type (
 	}
 )
 
-// NewRedis returns a Redis.
+// Deprecated: use MustNewRedis or NewRedis instead.
 func (rc RedisConf) NewRedis() *Redis {
 	var opts []Option
 	if rc.Type == ClusterType {

+ 40 - 1
core/stores/redis/redis.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"log"
 	"strconv"
 	"time"
 
@@ -85,8 +86,46 @@ type (
 	StringCmd = red.StringCmd
 )
 
-// New returns a Redis with given options.
+// Deprecated: use MustNewRedis or NewRedis instead.
 func New(addr string, opts ...Option) *Redis {
+	return newRedis(addr, opts...)
+}
+
+// MustNewRedis returns a Redis with given options.
+func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
+	rds, err := NewRedis(conf, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return rds
+}
+
+// NewRedis returns a Redis with given options.
+func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
+	if err := conf.Validate(); err != nil {
+		return nil, err
+	}
+
+	if conf.Type == ClusterType {
+		opts = append([]Option{Cluster()}, opts...)
+	}
+	if len(conf.Pass) > 0 {
+		opts = append([]Option{WithPass(conf.Pass)}, opts...)
+	}
+	if conf.Tls {
+		opts = append([]Option{WithTLS()}, opts...)
+	}
+
+	rds := newRedis(conf.Host, opts...)
+	if !rds.Ping() {
+		return nil, ErrPing
+	}
+
+	return rds, nil
+}
+
+func newRedis(addr string, opts ...Option) *Redis {
 	r := &Redis{
 		Addr: addr,
 		Type: NodeType,

+ 110 - 0
core/stores/redis/redis_test.go

@@ -16,6 +16,116 @@ import (
 	"github.com/zeromicro/go-zero/core/stringx"
 )
 
+func TestNewRedis(t *testing.T) {
+	r1, err := miniredis.Run()
+	assert.NoError(t, err)
+	defer r1.Close()
+
+	r2, err := miniredis.Run()
+	assert.NoError(t, err)
+	defer r2.Close()
+	r2.SetError("mock")
+
+	tests := []struct {
+		name string
+		RedisConf
+		ok       bool
+		redisErr bool
+	}{
+		{
+			name: "missing host",
+			RedisConf: RedisConf{
+				Host: "",
+				Type: NodeType,
+				Pass: "",
+			},
+			ok: false,
+		},
+		{
+			name: "missing type",
+			RedisConf: RedisConf{
+				Host: "localhost:6379",
+				Type: "",
+				Pass: "",
+			},
+			ok: false,
+		},
+		{
+			name: "ok",
+			RedisConf: RedisConf{
+				Host: r1.Addr(),
+				Type: NodeType,
+				Pass: "",
+			},
+			ok: true,
+		},
+		{
+			name: "ok",
+			RedisConf: RedisConf{
+				Host: r1.Addr(),
+				Type: ClusterType,
+				Pass: "",
+			},
+			ok: true,
+		},
+		{
+			name: "password",
+			RedisConf: RedisConf{
+				Host: r1.Addr(),
+				Type: NodeType,
+				Pass: "pw",
+			},
+			ok: true,
+		},
+		{
+			name: "tls",
+			RedisConf: RedisConf{
+				Host: r1.Addr(),
+				Type: NodeType,
+				Tls:  true,
+			},
+			ok: true,
+		},
+		{
+			name: "node error",
+			RedisConf: RedisConf{
+				Host: r2.Addr(),
+				Type: NodeType,
+				Pass: "",
+			},
+			ok:       true,
+			redisErr: true,
+		},
+		{
+			name: "cluster error",
+			RedisConf: RedisConf{
+				Host: r2.Addr(),
+				Type: ClusterType,
+				Pass: "",
+			},
+			ok:       true,
+			redisErr: true,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(stringx.RandId(), func(t *testing.T) {
+			rds, err := NewRedis(test.RedisConf)
+			if test.ok {
+				if test.redisErr {
+					assert.Error(t, err)
+					assert.Nil(t, rds)
+				} else {
+					assert.NoError(t, err)
+					assert.NotNil(t, rds)
+				}
+			} else {
+				assert.Error(t, err)
+			}
+		})
+	}
+}
+
 func TestRedis_Decr(t *testing.T) {
 	runOnRedis(t, func(client *Redis) {
 		_, err := New(client.Addr, badType()).Decr("a")

+ 7 - 1
zrpc/server.go

@@ -7,6 +7,7 @@ import (
 	"github.com/zeromicro/go-zero/core/load"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stat"
+	"github.com/zeromicro/go-zero/core/stores/redis"
 	"github.com/zeromicro/go-zero/zrpc/internal"
 	"github.com/zeromicro/go-zero/zrpc/internal/auth"
 	"github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors"
@@ -120,7 +121,12 @@ func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Me
 	}
 
 	if c.Auth {
-		authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl)
+		rds, err := redis.NewRedis(c.Redis.RedisConf)
+		if err != nil {
+			return err
+		}
+
+		authenticator, err := auth.NewAuthenticator(rds, c.Redis.Key, c.StrictControl)
 		if err != nil {
 			return err
 		}

+ 13 - 3
zrpc/server_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/alicebob/miniredis/v2"
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/discov"
 	"github.com/zeromicro/go-zero/core/logx"
@@ -16,12 +17,16 @@ import (
 )
 
 func TestServer_setupInterceptors(t *testing.T) {
+	rds, err := miniredis.Run()
+	assert.NoError(t, err)
+	defer rds.Close()
+
 	server := new(mockedServer)
-	err := setupInterceptors(server, RpcServerConf{
+	conf := RpcServerConf{
 		Auth: true,
 		Redis: redis.RedisKeyConf{
 			RedisConf: redis.RedisConf{
-				Host: "any",
+				Host: rds.Addr(),
 				Type: redis.NodeType,
 			},
 			Key: "foo",
@@ -35,10 +40,15 @@ func TestServer_setupInterceptors(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
-	}, new(stat.Metrics))
+	}
+	err = setupInterceptors(server, conf, new(stat.Metrics))
 	assert.Nil(t, err)
 	assert.Equal(t, 3, len(server.unaryInterceptors))
 	assert.Equal(t, 1, len(server.streamInterceptors))
+
+	rds.SetError("mock error")
+	err = setupInterceptors(server, conf, new(stat.Metrics))
+	assert.Error(t, err)
 }
 
 func TestServer(t *testing.T) {