浏览代码

feat: unique redis addrs and trim spaces (#3004)

Kevin Wan 2 年之前
父节点
当前提交
7a0c04bc21

+ 1 - 2
core/stores/redis/redisblockingnode.go

@@ -2,7 +2,6 @@ package redis
 
 import (
 	"fmt"
-	"strings"
 
 	red "github.com/go-redis/redis/v8"
 	"github.com/zeromicro/go-zero/core/logx"
@@ -32,7 +31,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
 		return &clientBridge{client}, nil
 	case ClusterType:
 		client := red.NewClusterClient(&red.ClusterOptions{
-			Addrs:        strings.Split(r.Addr, ","),
+			Addrs:        splitClusterAddrs(r.Addr),
 			Password:     r.Pass,
 			MaxRetries:   maxRetries,
 			PoolSize:     1,

+ 18 - 1
core/stores/redis/redisclustermanager.go

@@ -9,6 +9,8 @@ import (
 	"github.com/zeromicro/go-zero/core/syncx"
 )
 
+const addrSep = ","
+
 var clusterManager = syncx.NewResourceManager()
 
 func getCluster(r *Redis) (*red.ClusterClient, error) {
@@ -20,7 +22,7 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
 			}
 		}
 		store := red.NewClusterClient(&red.ClusterOptions{
-			Addrs:        strings.Split(r.Addr, ","),
+			Addrs:        splitClusterAddrs(r.Addr),
 			Password:     r.Pass,
 			MaxRetries:   maxRetries,
 			MinIdleConns: idleConns,
@@ -36,3 +38,18 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
 
 	return val.(*red.ClusterClient), nil
 }
+
+func splitClusterAddrs(addr string) []string {
+	addrs := strings.Split(addr, addrSep)
+	unique := make(map[string]struct{})
+	for _, each := range addrs {
+		unique[strings.TrimSpace(each)] = struct{}{}
+	}
+
+	addrs = addrs[:0]
+	for k := range unique {
+		addrs = append(addrs, k)
+	}
+
+	return addrs
+}

+ 43 - 0
core/stores/redis/redisclustermanager_test.go

@@ -0,0 +1,43 @@
+package redis
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestSplitClusterAddrs(t *testing.T) {
+	testCases := []struct {
+		name     string
+		input    string
+		expected []string
+	}{
+		{
+			name:     "empty input",
+			input:    "",
+			expected: []string{""},
+		},
+		{
+			name:     "single address",
+			input:    "127.0.0.1:8000",
+			expected: []string{"127.0.0.1:8000"},
+		},
+		{
+			name:     "multiple addresses with duplicates",
+			input:    "127.0.0.1:8000,127.0.0.1:8001, 127.0.0.1:8000",
+			expected: []string{"127.0.0.1:8000", "127.0.0.1:8001"},
+		},
+		{
+			name:     "multiple addresses without duplicates",
+			input:    "127.0.0.1:8000, 127.0.0.1:8001, 127.0.0.1:8002",
+			expected: []string{"127.0.0.1:8000", "127.0.0.1:8001", "127.0.0.1:8002"},
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			assert.ElementsMatch(t, tc.expected, splitClusterAddrs(tc.input))
+		})
+	}
+}