Pārlūkot izejas kodu

feat: support tls for etcd client (#1390)

* feat: support tls for etcd client

* chore: fix typo

* refactor: rename TrustedCAFile to CACertFile

* docs: add comments

* fix: missing tls registration

* feat: add InsecureSkipVerify config for testing
Kevin Wan 3 gadi atpakaļ
vecāks
revīzija
a7aeb8ac0e

+ 6 - 0
core/discov/accountregistry.go

@@ -6,3 +6,9 @@ import "github.com/tal-tech/go-zero/core/discov/internal"
 func RegisterAccount(endpoints []string, user, pass string) {
 	internal.AddAccount(endpoints, user, pass)
 }
+
+// RegisterTLS registers the CertFile/CertKeyFile/CACertFile to the given etcd.
+func RegisterTLS(endpoints []string, certFile, certKeyFile, caFile string,
+	insecureSkipVerify bool) error {
+	return internal.AddTLS(endpoints, certFile, certKeyFile, caFile, insecureSkipVerify)
+}

+ 13 - 4
core/discov/config.go

@@ -4,10 +4,14 @@ import "errors"
 
 // EtcdConf is the config item with the given key on etcd.
 type EtcdConf struct {
-	Hosts []string
-	Key   string
-	User  string `json:",optional"`
-	Pass  string `json:",optional"`
+	Hosts              []string
+	Key                string
+	User               string `json:",optional"`
+	Pass               string `json:",optional"`
+	CertFile           string `json:",optional"`
+	CertKeyFile        string `json:",optional=CertFile"`
+	CACertFile         string `json:",optional=CertFile"`
+	InsecureSkipVerify bool   `json:",optional"`
 }
 
 // HasAccount returns if account provided.
@@ -15,6 +19,11 @@ func (c EtcdConf) HasAccount() bool {
 	return len(c.User) > 0 && len(c.Pass) > 0
 }
 
+// HasTLS returns if TLS CertFile/CertKeyFile/CACertFile are provided.
+func (c EtcdConf) HasTLS() bool {
+	return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0
+}
+
 // Validate validates c.
 func (c EtcdConf) Validate() error {
 	if len(c.Hosts) == 0 {

+ 44 - 3
core/discov/internal/accountmanager.go

@@ -1,10 +1,16 @@
 package internal
 
-import "sync"
+import (
+	"crypto/tls"
+	"crypto/x509"
+	"io/ioutil"
+	"sync"
+)
 
 var (
-	accounts = make(map[string]Account)
-	lock     sync.RWMutex
+	accounts   = make(map[string]Account)
+	tlsConfigs = make(map[string]*tls.Config)
+	lock       sync.RWMutex
 )
 
 // Account holds the username/password for an etcd cluster.
@@ -24,6 +30,32 @@ func AddAccount(endpoints []string, user, pass string) {
 	}
 }
 
+// AddTLS adds the tls cert files for the given etcd cluster.
+func AddTLS(endpoints []string, certFile, certKeyFile, caFile string, insecureSkipVerify bool) error {
+	cert, err := tls.LoadX509KeyPair(certFile, certKeyFile)
+	if err != nil {
+		return err
+	}
+
+	caData, err := ioutil.ReadFile(caFile)
+	if err != nil {
+		return err
+	}
+
+	pool := x509.NewCertPool()
+	pool.AppendCertsFromPEM(caData)
+
+	lock.Lock()
+	defer lock.Unlock()
+	tlsConfigs[getClusterKey(endpoints)] = &tls.Config{
+		Certificates:       []tls.Certificate{cert},
+		RootCAs:            pool,
+		InsecureSkipVerify: insecureSkipVerify,
+	}
+
+	return nil
+}
+
 // GetAccount gets the username/password for the given etcd cluster.
 func GetAccount(endpoints []string) (Account, bool) {
 	lock.RLock()
@@ -32,3 +64,12 @@ func GetAccount(endpoints []string) (Account, bool) {
 	account, ok := accounts[getClusterKey(endpoints)]
 	return account, ok
 }
+
+// GetTLS gets the tls config for the given etcd cluster.
+func GetTLS(endpoints []string) (*tls.Config, bool) {
+	lock.RLock()
+	defer lock.RUnlock()
+
+	cfg, ok := tlsConfigs[getClusterKey(endpoints)]
+	return cfg, ok
+}

+ 3 - 0
core/discov/internal/registry.go

@@ -337,6 +337,9 @@ func DialClient(endpoints []string) (EtcdClient, error) {
 		cfg.Username = account.User
 		cfg.Password = account.Pass
 	}
+	if tlsCfg, ok := GetTLS(endpoints); ok {
+		cfg.TLS = tlsCfg
+	}
 
 	return clientv3.New(cfg)
 }

+ 12 - 5
core/discov/publisher.go

@@ -145,16 +145,23 @@ func (p *Publisher) revoke(cli internal.EtcdClient) {
 	}
 }
 
+// WithId customizes a Publisher with the id.
+func WithId(id int64) PubOption {
+	return func(publisher *Publisher) {
+		publisher.id = id
+	}
+}
+
 // WithPubEtcdAccount provides the etcd username/password.
 func WithPubEtcdAccount(user, pass string) PubOption {
 	return func(pub *Publisher) {
-		internal.AddAccount(pub.endpoints, user, pass)
+		RegisterAccount(pub.endpoints, user, pass)
 	}
 }
 
-// WithId customizes a Publisher with the id.
-func WithId(id int64) PubOption {
-	return func(publisher *Publisher) {
-		publisher.id = id
+// WithPubEtcdTLS provides the etcd CertFile/CertKeyFile/CACertFile.
+func WithPubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify bool) PubOption {
+	return func(pub *Publisher) {
+		logx.Must(RegisterTLS(pub.endpoints, certFile, certKeyFile, caFile, insecureSkipVerify))
 	}
 }

+ 10 - 2
core/discov/subscriber.go

@@ -5,6 +5,7 @@ import (
 	"sync/atomic"
 
 	"github.com/tal-tech/go-zero/core/discov/internal"
+	"github.com/tal-tech/go-zero/core/logx"
 	"github.com/tal-tech/go-zero/core/syncx"
 )
 
@@ -58,10 +59,17 @@ func Exclusive() SubOption {
 	}
 }
 
-// WithSubEtcdAccount customizes the Subscriber with given etcd username/password.
+// WithSubEtcdAccount provides the etcd username/password.
 func WithSubEtcdAccount(user, pass string) SubOption {
 	return func(sub *Subscriber) {
-		internal.AddAccount(sub.endpoints, user, pass)
+		RegisterAccount(sub.endpoints, user, pass)
+	}
+}
+
+// WithSubEtcdTLS provides the etcd CertFile/CertKeyFile/CACertFile.
+func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify bool) SubOption {
+	return func(sub *Subscriber) {
+		logx.Must(RegisterTLS(sub.endpoints, certFile, certKeyFile, caFile, insecureSkipVerify))
 	}
 }
 

+ 6 - 0
zrpc/config.go

@@ -83,6 +83,12 @@ func (cc RpcClientConf) BuildTarget() (string, error) {
 	if cc.Etcd.HasAccount() {
 		discov.RegisterAccount(cc.Etcd.Hosts, cc.Etcd.User, cc.Etcd.Pass)
 	}
+	if cc.Etcd.HasTLS() {
+		if err := discov.RegisterTLS(cc.Etcd.Hosts, cc.Etcd.CertFile, cc.Etcd.CertKeyFile,
+			cc.Etcd.CACertFile, cc.Etcd.InsecureSkipVerify); err != nil {
+			return "", err
+		}
+	}
 
 	return resolver.BuildDiscovTarget(cc.Etcd.Hosts, cc.Etcd.Key), nil
 }

+ 4 - 0
zrpc/internal/rpcpubserver.go

@@ -21,6 +21,10 @@ func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption
 		if etcd.HasAccount() {
 			pubOpts = append(pubOpts, discov.WithPubEtcdAccount(etcd.User, etcd.Pass))
 		}
+		if etcd.HasTLS() {
+			pubOpts = append(pubOpts, discov.WithPubEtcdTLS(etcd.CertFile, etcd.CertKeyFile,
+				etcd.CACertFile, etcd.InsecureSkipVerify))
+		}
 		pubClient := discov.NewPublisher(etcd.Hosts, etcd.Key, pubListenOn, pubOpts...)
 		return pubClient.KeepAlive()
 	}