Преглед изворни кода

redis增加tls支持 (#595)

* redis连接增加支持tls选项

* 优化redis tls config 写法

* redis增加tls支持

* 增加redis tls测试用例,但redis tls local server不支持,测试用例全部NotNil

Co-authored-by: liuyi <liuyi@fangyb.com>
Co-authored-by: yi.liu <yi.liu@xshoppy.com>
r00mz пре 4 година
родитељ
комит
8cb6490724

+ 7 - 3
core/stores/redis/conf.go

@@ -14,9 +14,10 @@ var (
 type (
 	// A RedisConf is a redis config.
 	RedisConf struct {
-		Host string
-		Type string `json:",default=node,options=node|cluster"`
-		Pass string `json:",optional"`
+		Host    string
+		Type    string `json:",default=node,options=node|cluster"`
+		Pass    string `json:",optional"`
+		TLSFlag bool   `json:",default=false,options=true|false"`
 	}
 
 	// A RedisKeyConf is a redis config with key.
@@ -28,6 +29,9 @@ type (
 
 // NewRedis returns a Redis.
 func (rc RedisConf) NewRedis() *Redis {
+	if rc.TLSFlag {
+		return NewRedisWithTLS(rc.Host, rc.Type, rc.TLSFlag, rc.Pass)
+	}
 	return NewRedis(rc.Host, rc.Type, rc.Pass)
 }
 

+ 24 - 10
core/stores/redis/redis.go

@@ -37,10 +37,11 @@ type (
 
 	// Redis defines a redis node/cluster. It is thread-safe.
 	Redis struct {
-		Addr string
-		Type string
-		Pass string
-		brk  breaker.Breaker
+		Addr    string
+		Type    string
+		Pass    string
+		brk     breaker.Breaker
+		TLSFlag bool
 	}
 
 	// RedisNode interface represents a redis node.
@@ -71,16 +72,21 @@ type (
 
 // NewRedis returns a Redis.
 func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis {
+	return NewRedisWithTLS(redisAddr, redisType, false, redisPass...)
+}
+
+func NewRedisWithTLS(redisAddr, redisType string, tlsFlag bool, redisPass ...string) *Redis {
 	var pass string
 	for _, v := range redisPass {
 		pass = v
 	}
 
 	return &Redis{
-		Addr: redisAddr,
-		Type: redisType,
-		Pass: pass,
-		brk:  breaker.NewBreaker(),
+		Addr:    redisAddr,
+		Type:    redisType,
+		Pass:    pass,
+		brk:     breaker.NewBreaker(),
+		TLSFlag: tlsFlag,
 	}
 }
 
@@ -1704,9 +1710,17 @@ func acceptable(err error) bool {
 func getRedis(r *Redis) (RedisNode, error) {
 	switch r.Type {
 	case ClusterType:
-		return getCluster(r.Addr, r.Pass)
+		if r.TLSFlag {
+			return getClusterWithTLS(r.Addr, r.Pass, r.TLSFlag)
+		} else {
+			return getCluster(r.Addr, r.Pass)
+		}
 	case NodeType:
-		return getClient(r.Addr, r.Pass)
+		if r.TLSFlag {
+			return getClientWithTLS(r.Addr, r.Pass, r.TLSFlag)
+		} else {
+			return getClient(r.Addr, r.Pass)
+		}
 	default:
 		return nil, fmt.Errorf("redis type '%s' is not supported", r.Type)
 	}

+ 36 - 1
core/stores/redis/redis_test.go

@@ -1,6 +1,7 @@
 package redis
 
 import (
+	"crypto/tls"
 	"errors"
 	"io"
 	"strconv"
@@ -26,6 +27,20 @@ func TestRedis_Exists(t *testing.T) {
 	})
 }
 
+func TestRedisTLS_Exists(t *testing.T) {
+	runOnRedisTLS(t, func(client *Redis) {
+		_, err := NewRedisWithTLS(client.Addr, "", true).Exists("a")
+		assert.NotNil(t, err)
+		ok, err := client.Exists("a")
+		assert.NotNil(t, err)
+		assert.False(t, ok)
+		assert.NotNil(t, client.Set("a", "b"))
+		ok, err = client.Exists("a")
+		assert.NotNil(t, err)
+		assert.False(t, ok)
+	})
+}
+
 func TestRedis_Eval(t *testing.T) {
 	runOnRedis(t, func(client *Redis) {
 		_, err := NewRedis(client.Addr, "").Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
@@ -1062,8 +1077,28 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
 			client.Close()
 		}
 	}()
-
 	fn(NewRedis(s.Addr(), NodeType))
+
+}
+
+func runOnRedisTLS(t *testing.T, fn func(client *Redis)) {
+	s, err := miniredis.RunTLS(&tls.Config{
+		Certificates:       make([]tls.Certificate, 1),
+		InsecureSkipVerify: true,
+	})
+	assert.Nil(t, err)
+	defer func() {
+		client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
+			return nil, errors.New("should already exist")
+		})
+		if err != nil {
+			t.Error(err)
+		}
+		if client != nil {
+			client.Close()
+		}
+	}()
+	fn(NewRedisWithTLS(s.Addr(), NodeType, true))
 }
 
 type mockedNode struct {

+ 12 - 0
core/stores/redis/redisclientmanager.go

@@ -1,6 +1,7 @@
 package redis
 
 import (
+	"crypto/tls"
 	"io"
 
 	red "github.com/go-redis/redis"
@@ -16,13 +17,24 @@ const (
 var clientManager = syncx.NewResourceManager()
 
 func getClient(server, pass string) (*red.Client, error) {
+	return getClientWithTLS(server, pass, false)
+}
+
+func getClientWithTLS(server, pass string, tlsFlag bool) (*red.Client, error) {
 	val, err := clientManager.GetResource(server, func() (io.Closer, error) {
+		var tlsConfig *tls.Config = nil
+		if tlsFlag {
+			tlsConfig = &tls.Config{
+				InsecureSkipVerify: true,
+			}
+		}
 		store := red.NewClient(&red.Options{
 			Addr:         server,
 			Password:     pass,
 			DB:           defaultDatabase,
 			MaxRetries:   maxRetries,
 			MinIdleConns: idleConns,
+			TLSConfig:    tlsConfig,
 		})
 		store.WrapProcess(process)
 		return store, nil

+ 12 - 0
core/stores/redis/redisclustermanager.go

@@ -1,6 +1,7 @@
 package redis
 
 import (
+	"crypto/tls"
 	"io"
 
 	red "github.com/go-redis/redis"
@@ -10,12 +11,23 @@ import (
 var clusterManager = syncx.NewResourceManager()
 
 func getCluster(server, pass string) (*red.ClusterClient, error) {
+	return getClusterWithTLS(server, pass, false)
+}
+
+func getClusterWithTLS(server, pass string, tlsFlag bool) (*red.ClusterClient, error) {
 	val, err := clusterManager.GetResource(server, func() (io.Closer, error) {
+		var tlsConfig *tls.Config = nil
+		if tlsFlag {
+			tlsConfig = &tls.Config{
+				InsecureSkipVerify: true,
+			}
+		}
 		store := red.NewClusterClient(&red.ClusterOptions{
 			Addrs:        []string{server},
 			Password:     pass,
 			MaxRetries:   maxRetries,
 			MinIdleConns: idleConns,
+			TLSConfig:    tlsConfig,
 		})
 		store.WrapProcess(process)