瀏覽代碼

feat: support auth account for etcd (#1174)

Kevin Wan 3 年之前
父節點
當前提交
59b9687f31

+ 7 - 0
core/discov/accountregistry.go

@@ -0,0 +1,7 @@
+package discov
+
+import "github.com/tal-tech/go-zero/core/discov/internal"
+
+func RegisterAccount(endpoints []string, user, pass string) {
+	internal.AddAccount(endpoints, user, pass)
+}

+ 21 - 0
core/discov/accountregistry_test.go

@@ -0,0 +1,21 @@
+package discov
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/discov/internal"
+	"github.com/tal-tech/go-zero/core/stringx"
+)
+
+func TestRegisterAccount(t *testing.T) {
+	endpoints := []string{
+		"localhost:2379",
+	}
+	user := "foo" + stringx.Rand()
+	RegisterAccount(endpoints, user, "bar")
+	account, ok := internal.GetAccount(endpoints)
+	assert.True(t, ok)
+	assert.Equal(t, user, account.User)
+	assert.Equal(t, "bar", account.Pass)
+}

+ 7 - 0
core/discov/config.go

@@ -6,6 +6,13 @@ import "errors"
 type EtcdConf struct {
 	Hosts []string
 	Key   string
+	User  string `json:",optional"`
+	Pass  string `json:",optional"`
+}
+
+// HasAccount returns if account provided.
+func (c EtcdConf) HasAccount() bool {
+	return len(c.User) > 0 && len(c.Pass) > 0
 }
 
 // Validate validates c.

+ 36 - 0
core/discov/config_test.go

@@ -44,3 +44,39 @@ func TestConfig(t *testing.T) {
 		}
 	}
 }
+
+func TestEtcdConf_HasAccount(t *testing.T) {
+	tests := []struct {
+		EtcdConf
+		hasAccount bool
+	}{
+		{
+			EtcdConf: EtcdConf{
+				Hosts: []string{"any"},
+				Key:   "key",
+			},
+			hasAccount: false,
+		},
+		{
+			EtcdConf: EtcdConf{
+				Hosts: []string{"any"},
+				Key:   "key",
+				User:  "foo",
+			},
+			hasAccount: false,
+		},
+		{
+			EtcdConf: EtcdConf{
+				Hosts: []string{"any"},
+				Key:   "key",
+				User:  "foo",
+				Pass:  "bar",
+			},
+			hasAccount: true,
+		},
+	}
+
+	for _, test := range tests {
+		assert.Equal(t, test.hasAccount, test.EtcdConf.HasAccount())
+	}
+}

+ 31 - 0
core/discov/internal/accountmanager.go

@@ -0,0 +1,31 @@
+package internal
+
+import "sync"
+
+type Account struct {
+	User string
+	Pass string
+}
+
+var (
+	accounts = make(map[string]Account)
+	lock     sync.RWMutex
+)
+
+func AddAccount(endpoints []string, user, pass string) {
+	lock.Lock()
+	defer lock.Unlock()
+
+	accounts[getClusterKey(endpoints)] = Account{
+		User: user,
+		Pass: pass,
+	}
+}
+
+func GetAccount(endpoints []string) (Account, bool) {
+	lock.RLock()
+	defer lock.RUnlock()
+
+	account, ok := accounts[getClusterKey(endpoints)]
+	return account, ok
+}

+ 34 - 0
core/discov/internal/accountmanager_test.go

@@ -0,0 +1,34 @@
+package internal
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/stringx"
+)
+
+func TestAccount(t *testing.T) {
+	endpoints := []string{
+		"192.168.0.2:2379",
+		"192.168.0.3:2379",
+		"192.168.0.4:2379",
+	}
+	username := "foo" + stringx.Rand()
+	password := "bar"
+	anotherPassword := "any"
+
+	_, ok := GetAccount(endpoints)
+	assert.False(t, ok)
+
+	AddAccount(endpoints, username, password)
+	account, ok := GetAccount(endpoints)
+	assert.True(t, ok)
+	assert.Equal(t, username, account.User)
+	assert.Equal(t, password, account.Pass)
+
+	AddAccount(endpoints, username, anotherPassword)
+	account, ok = GetAccount(endpoints)
+	assert.True(t, ok)
+	assert.Equal(t, username, account.User)
+	assert.Equal(t, anotherPassword, account.Pass)
+}

+ 8 - 2
core/discov/internal/registry.go

@@ -302,14 +302,20 @@ func (c *cluster) watchConnState(cli EtcdClient) {
 
 // DialClient dials an etcd cluster with given endpoints.
 func DialClient(endpoints []string) (EtcdClient, error) {
-	return clientv3.New(clientv3.Config{
+	cfg := clientv3.Config{
 		Endpoints:            endpoints,
 		AutoSyncInterval:     autoSyncInterval,
 		DialTimeout:          DialTimeout,
 		DialKeepAliveTime:    dialKeepAliveTime,
 		DialKeepAliveTimeout: DialTimeout,
 		RejectOldCluster:     true,
-	})
+	}
+	if account, ok := GetAccount(endpoints); ok {
+		cfg.Username = account.User
+		cfg.Password = account.Pass
+	}
+
+	return clientv3.New(cfg)
 }
 
 func getClusterKey(endpoints []string) string {

+ 1 - 0
core/discov/internal/registry_test.go

@@ -33,6 +33,7 @@ func setMockClient(cli EtcdClient) func() {
 }
 
 func TestGetCluster(t *testing.T) {
+	AddAccount([]string{"first"}, "foo", "bar")
 	c1 := GetRegistry().getCluster([]string{"first"})
 	c2 := GetRegistry().getCluster([]string{"second"})
 	c3 := GetRegistry().getCluster([]string{"first"})

+ 11 - 4
core/discov/publisher.go

@@ -11,8 +11,8 @@ import (
 )
 
 type (
-	// PublisherOption defines the method to customize a Publisher.
-	PublisherOption func(client *Publisher)
+	// PubOption defines the method to customize a Publisher.
+	PubOption func(client *Publisher)
 
 	// A Publisher can be used to publish the value to an etcd cluster on the given key.
 	Publisher struct {
@@ -32,7 +32,7 @@ type (
 // endpoints is the hosts of the etcd cluster.
 // key:value are a pair to be published.
 // opts are used to customize the Publisher.
-func NewPublisher(endpoints []string, key, value string, opts ...PublisherOption) *Publisher {
+func NewPublisher(endpoints []string, key, value string, opts ...PubOption) *Publisher {
 	publisher := &Publisher{
 		endpoints:  endpoints,
 		key:        key,
@@ -145,8 +145,15 @@ func (p *Publisher) revoke(cli internal.EtcdClient) {
 	}
 }
 
+// WithPubEtcdAccount provides the etcd username/password.
+func WithPubEtcdAccount(user, pass string) PubOption {
+	return func(pub *Publisher) {
+		internal.AddAccount(pub.endpoints, user, pass)
+	}
+}
+
 // WithId customizes a Publisher with the id.
-func WithId(id int64) PublisherOption {
+func WithId(id int64) PubOption {
 	return func(publisher *Publisher) {
 		publisher.id = id
 	}

+ 3 - 1
core/discov/publisher_test.go

@@ -11,6 +11,7 @@ import (
 	"github.com/tal-tech/go-zero/core/discov/internal"
 	"github.com/tal-tech/go-zero/core/lang"
 	"github.com/tal-tech/go-zero/core/logx"
+	"github.com/tal-tech/go-zero/core/stringx"
 	clientv3 "go.etcd.io/etcd/client/v3"
 )
 
@@ -30,7 +31,8 @@ func TestPublisher_register(t *testing.T) {
 		ID: id,
 	}, nil)
 	cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
-	pub := NewPublisher(nil, "thekey", "thevalue")
+	pub := NewPublisher(nil, "thekey", "thevalue",
+		WithPubEtcdAccount(stringx.Rand(), "bar"))
 	_, err := pub.register(cli)
 	assert.Nil(t, err)
 }

+ 17 - 13
core/discov/subscriber.go

@@ -9,16 +9,14 @@ import (
 )
 
 type (
-	subOptions struct {
-		exclusive bool
-	}
-
 	// SubOption defines the method to customize a Subscriber.
-	SubOption func(opts *subOptions)
+	SubOption func(sub *Subscriber)
 
 	// A Subscriber is used to subscribe the given key on a etcd cluster.
 	Subscriber struct {
-		items *container
+		endpoints []string
+		exclusive bool
+		items     *container
 	}
 )
 
@@ -27,14 +25,14 @@ type (
 // key is the key to subscribe.
 // opts are used to customize the Subscriber.
 func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) {
-	var subOpts subOptions
+	sub := &Subscriber{
+		endpoints: endpoints,
+	}
 	for _, opt := range opts {
-		opt(&subOpts)
+		opt(sub)
 	}
+	sub.items = newContainer(sub.exclusive)
 
-	sub := &Subscriber{
-		items: newContainer(subOpts.exclusive),
-	}
 	if err := internal.GetRegistry().Monitor(endpoints, key, sub.items); err != nil {
 		return nil, err
 	}
@@ -55,8 +53,14 @@ func (s *Subscriber) Values() []string {
 // Exclusive means that key value can only be 1:1,
 // which means later added value will remove the keys associated with the same value previously.
 func Exclusive() SubOption {
-	return func(opts *subOptions) {
-		opts.exclusive = true
+	return func(sub *Subscriber) {
+		sub.exclusive = true
+	}
+}
+
+func WithSubEtcdAccount(user, pass string) SubOption {
+	return func(sub *Subscriber) {
+		internal.AddAccount(sub.endpoints, user, pass)
 	}
 }
 

+ 15 - 4
core/discov/subscriber_test.go

@@ -6,6 +6,7 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/core/discov/internal"
+	"github.com/tal-tech/go-zero/core/stringx"
 )
 
 const (
@@ -201,11 +202,9 @@ func TestContainer(t *testing.T) {
 }
 
 func TestSubscriber(t *testing.T) {
-	var opt subOptions
-	Exclusive()(&opt)
-
 	sub := new(Subscriber)
-	sub.items = newContainer(opt.exclusive)
+	Exclusive()(sub)
+	sub.items = newContainer(sub.exclusive)
 	var count int32
 	sub.AddListener(func() {
 		atomic.AddInt32(&count, 1)
@@ -214,3 +213,15 @@ func TestSubscriber(t *testing.T) {
 	assert.Empty(t, sub.Values())
 	assert.Equal(t, int32(1), atomic.LoadInt32(&count))
 }
+
+func TestWithSubEtcdAccount(t *testing.T) {
+	endpoints := []string{"localhost:2379"}
+	user := stringx.Rand()
+	WithSubEtcdAccount(user, "bar")(&Subscriber{
+		endpoints: endpoints,
+	})
+	account, ok := internal.GetAccount(endpoints)
+	assert.True(t, ok)
+	assert.Equal(t, user, account.User)
+	assert.Equal(t, "bar", account.Pass)
+}

+ 5 - 0
zrpc/client.go

@@ -4,6 +4,7 @@ import (
 	"log"
 	"time"
 
+	"github.com/tal-tech/go-zero/core/discov"
 	"github.com/tal-tech/go-zero/zrpc/internal"
 	"github.com/tal-tech/go-zero/zrpc/internal/auth"
 	"google.golang.org/grpc"
@@ -74,6 +75,10 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
 			return nil, err
 		}
 
+		if c.Etcd.HasAccount() {
+			discov.RegisterAccount(c.Etcd.Hosts, c.Etcd.User, c.Etcd.Pass)
+		}
+
 		target = internal.BuildDiscovTarget(c.Etcd.Hosts, c.Etcd.Key)
 	}
 

+ 6 - 3
zrpc/internal/rpcpubserver.go

@@ -14,11 +14,14 @@ const (
 )
 
 // NewRpcPubServer returns a Server.
-func NewRpcPubServer(etcdEndpoints []string, etcdKey, listenOn string,
-	opts ...ServerOption) (Server, error) {
+func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption) (Server, error) {
 	registerEtcd := func() error {
 		pubListenOn := figureOutListenOn(listenOn)
-		pubClient := discov.NewPublisher(etcdEndpoints, etcdKey, pubListenOn)
+		var pubOpts []discov.PubOption
+		if etcd.HasAccount() {
+			pubOpts = append(pubOpts, discov.WithPubEtcdAccount(etcd.User, etcd.Pass))
+		}
+		pubClient := discov.NewPublisher(etcd.Hosts, etcd.Key, pubListenOn, pubOpts...)
 		return pubClient.KeepAlive()
 	}
 	server := keepAliveServer{

+ 1 - 1
zrpc/server.go

@@ -41,7 +41,7 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error
 	serverOptions := []internal.ServerOption{internal.WithMetrics(metrics), internal.WithMaxRetries(c.MaxRetries)}
 
 	if c.HasEtcd() {
-		server, err = internal.NewRpcPubServer(c.Etcd.Hosts, c.Etcd.Key, c.ListenOn, serverOptions...)
+		server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, serverOptions...)
 		if err != nil {
 			return nil, err
 		}