Browse Source

support direct scheme on rpc resolver

kevin 4 years ago
parent
commit
f03cfb0ff7

+ 26 - 6
example/rpc/client/lb/main.go

@@ -2,7 +2,9 @@ package main
 
 import (
 	"context"
+	"flag"
 	"fmt"
+	"log"
 	"time"
 
 	"github.com/tal-tech/go-zero/core/discov"
@@ -10,13 +12,31 @@ import (
 	"github.com/tal-tech/go-zero/rpcx"
 )
 
+var lb = flag.String("t", "direct", "the load balancer type")
+
 func main() {
-	cli := rpcx.MustNewClient(rpcx.RpcClientConf{
-		Etcd: discov.EtcdConf{
-			Hosts: []string{"localhost:2379"},
-			Key:   "rpcx",
-		},
-	})
+	flag.Parse()
+
+	var cli rpcx.Client
+	switch *lb {
+	case "direct":
+		cli = rpcx.MustNewClient(rpcx.RpcClientConf{
+			Endpoints: []string{
+				"localhost:3456",
+				"localhost:3457",
+			},
+		})
+	case "discov":
+		cli = rpcx.MustNewClient(rpcx.RpcClientConf{
+			Etcd: discov.EtcdConf{
+				Hosts: []string{"localhost:2379"},
+				Key:   "rpcx",
+			},
+		})
+	default:
+		log.Fatal("bad load balancing type")
+	}
+
 	greet := unary.NewGreeterClient(cli.Conn())
 	ticker := time.NewTicker(time.Second)
 	defer ticker.Stop()

+ 8 - 4
rpcx/client.go

@@ -49,10 +49,10 @@ func NewClient(c RpcClientConf, options ...internal.ClientOption) (Client, error
 
 	var client Client
 	var err error
-	if len(c.Server) > 0 {
-		client, err = internal.NewDirectClient(c.Server, opts...)
+	if len(c.Endpoints) > 0 {
+		client, err = internal.NewClient(internal.BuildDirectTarget(c.Endpoints), opts...)
 	} else if err = c.Etcd.Validate(); err == nil {
-		client, err = internal.NewDiscovClient(c.Etcd.Hosts, c.Etcd.Key, opts...)
+		client, err = internal.NewClient(internal.BuildDiscovTarget(c.Etcd.Hosts, c.Etcd.Key), opts...)
 	}
 	if err != nil {
 		return nil, err
@@ -64,7 +64,7 @@ func NewClient(c RpcClientConf, options ...internal.ClientOption) (Client, error
 }
 
 func NewClientNoAuth(c discov.EtcdConf) (Client, error) {
-	client, err := internal.NewDiscovClient(c.Hosts, c.Key)
+	client, err := internal.NewClient(internal.BuildDiscovTarget(c.Hosts, c.Key))
 	if err != nil {
 		return nil, err
 	}
@@ -74,6 +74,10 @@ func NewClientNoAuth(c discov.EtcdConf) (Client, error) {
 	}, nil
 }
 
+func NewClientWithTarget(target string, opts ...internal.ClientOption) (Client, error) {
+	return internal.NewClient(target, opts...)
+}
+
 func (rc *RpcClient) Conn() *grpc.ClientConn {
 	return rc.client.Conn()
 }

+ 9 - 9
rpcx/config.go

@@ -21,19 +21,19 @@ type (
 	}
 
 	RpcClientConf struct {
-		Etcd    discov.EtcdConf `json:",optional"`
-		Server  string          `json:",optional=!Etcd"`
-		App     string          `json:",optional"`
-		Token   string          `json:",optional"`
-		Timeout int64           `json:",optional"`
+		Etcd      discov.EtcdConf `json:",optional"`
+		Endpoints []string        `json:",optional=!Etcd"`
+		App       string          `json:",optional"`
+		Token     string          `json:",optional"`
+		Timeout   int64           `json:",optional"`
 	}
 )
 
-func NewDirectClientConf(server, app, token string) RpcClientConf {
+func NewDirectClientConf(endpoints []string, app, token string) RpcClientConf {
 	return RpcClientConf{
-		Server: server,
-		App:    app,
-		Token:  token,
+		Endpoints: endpoints,
+		App:       app,
+		Token:     token,
 	}
 }
 

+ 24 - 0
rpcx/internal/client.go

@@ -5,12 +5,18 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/tal-tech/go-zero/rpcx/internal/balancer/p2c"
 	"github.com/tal-tech/go-zero/rpcx/internal/clientinterceptors"
+	"github.com/tal-tech/go-zero/rpcx/internal/resolver"
 	"google.golang.org/grpc"
 )
 
 const dialTimeout = time.Second * 3
 
+func init() {
+	resolver.RegisterResolver()
+}
+
 type (
 	ClientOptions struct {
 		Timeout     time.Duration
@@ -18,8 +24,26 @@ type (
 	}
 
 	ClientOption func(options *ClientOptions)
+
+	client struct {
+		conn *grpc.ClientConn
+	}
 )
 
+func NewClient(target string, opts ...ClientOption) (*client, error) {
+	opts = append(opts, WithDialOption(grpc.WithBalancerName(p2c.Name)))
+	conn, err := dial(target, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	return &client{conn: conn}, nil
+}
+
+func (c *client) Conn() *grpc.ClientConn {
+	return c.conn
+}
+
 func WithDialOption(opt grpc.DialOption) ClientOption {
 	return func(options *ClientOptions) {
 		options.DialOptions = append(options.DialOptions, opt)

+ 0 - 26
rpcx/internal/directclient.go

@@ -1,26 +0,0 @@
-package internal
-
-import (
-	"google.golang.org/grpc"
-	"google.golang.org/grpc/balancer/roundrobin"
-)
-
-type DirectClient struct {
-	conn *grpc.ClientConn
-}
-
-func NewDirectClient(server string, opts ...ClientOption) (*DirectClient, error) {
-	opts = append(opts, WithDialOption(grpc.WithBalancerName(roundrobin.Name)))
-	conn, err := dial(server, opts...)
-	if err != nil {
-		return nil, err
-	}
-
-	return &DirectClient{
-		conn: conn,
-	}, nil
-}
-
-func (c *DirectClient) Conn() *grpc.ClientConn {
-	return c.conn
-}

+ 0 - 34
rpcx/internal/discovclient.go

@@ -1,34 +0,0 @@
-package internal
-
-import (
-	"fmt"
-	"strings"
-
-	"github.com/tal-tech/go-zero/rpcx/internal/balancer/p2c"
-	"github.com/tal-tech/go-zero/rpcx/internal/resolver"
-	"google.golang.org/grpc"
-)
-
-func init() {
-	resolver.RegisterResolver()
-}
-
-type DiscovClient struct {
-	conn *grpc.ClientConn
-}
-
-func NewDiscovClient(endpoints []string, key string, opts ...ClientOption) (*DiscovClient, error) {
-	opts = append(opts, WithDialOption(grpc.WithBalancerName(p2c.Name)))
-	target := fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme,
-		strings.Join(endpoints, resolver.EndpointSep), key)
-	conn, err := dial(target, opts...)
-	if err != nil {
-		return nil, err
-	}
-
-	return &DiscovClient{conn: conn}, nil
-}
-
-func (c *DiscovClient) Conn() *grpc.ClientConn {
-	return c.conn
-}

+ 30 - 0
rpcx/internal/resolver/directbuilder.go

@@ -0,0 +1,30 @@
+package resolver
+
+import (
+	"strings"
+
+	"google.golang.org/grpc/resolver"
+)
+
+type directBuilder struct{}
+
+func (d *directBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
+	resolver.Resolver, error) {
+	var addrs []resolver.Address
+	endpoints := strings.Split(target.Endpoint, EndpointSep)
+
+	for _, val := range subset(endpoints, subsetSize) {
+		addrs = append(addrs, resolver.Address{
+			Addr: val,
+		})
+	}
+	cc.UpdateState(resolver.State{
+		Addresses: addrs,
+	})
+
+	return &nopResolver{cc: cc}, nil
+}
+
+func (d *directBuilder) Scheme() string {
+	return DirectScheme
+}

+ 39 - 0
rpcx/internal/resolver/discovbuilder.go

@@ -0,0 +1,39 @@
+package resolver
+
+import (
+	"strings"
+
+	"github.com/tal-tech/go-zero/core/discov"
+	"google.golang.org/grpc/resolver"
+)
+
+type discovBuilder struct{}
+
+func (d *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
+	resolver.Resolver, error) {
+	hosts := strings.Split(target.Authority, EndpointSep)
+	sub, err := discov.NewSubscriber(hosts, target.Endpoint)
+	if err != nil {
+		return nil, err
+	}
+
+	update := func() {
+		var addrs []resolver.Address
+		for _, val := range subset(sub.Values(), subsetSize) {
+			addrs = append(addrs, resolver.Address{
+				Addr: val,
+			})
+		}
+		cc.UpdateState(resolver.State{
+			Addresses: addrs,
+		})
+	}
+	sub.AddListener(update)
+	update()
+
+	return &nopResolver{cc: cc}, nil
+}
+
+func (d *discovBuilder) Scheme() string {
+	return DiscovScheme
+}

+ 12 - 50
rpcx/internal/resolver/resolver.go

@@ -1,68 +1,30 @@
 package resolver
 
-import (
-	"fmt"
-	"strings"
-
-	"github.com/tal-tech/go-zero/core/discov"
-	"google.golang.org/grpc/resolver"
-)
+import "google.golang.org/grpc/resolver"
 
 const (
+	DirectScheme = "direct"
 	DiscovScheme = "discov"
 	EndpointSep  = ","
 	subsetSize   = 32
 )
 
-var builder discovBuilder
-
-type discovBuilder struct{}
-
-func (b *discovBuilder) Scheme() string {
-	return DiscovScheme
-}
-
-func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
-	resolver.Resolver, error) {
-	if target.Scheme != DiscovScheme {
-		return nil, fmt.Errorf("bad scheme: %s", target.Scheme)
-	}
-
-	hosts := strings.Split(target.Authority, EndpointSep)
-	sub, err := discov.NewSubscriber(hosts, target.Endpoint)
-	if err != nil {
-		return nil, err
-	}
-
-	update := func() {
-		var addrs []resolver.Address
-		for _, val := range subset(sub.Values(), subsetSize) {
-			addrs = append(addrs, resolver.Address{
-				Addr: val,
-			})
-		}
-		cc.UpdateState(resolver.State{
-			Addresses: addrs,
-		})
-	}
-	sub.AddListener(update)
-	update()
+var (
+	dirBuilder directBuilder
+	disBuilder discovBuilder
+)
 
-	return &discovResolver{
-		cc: cc,
-	}, nil
+func RegisterResolver() {
+	resolver.Register(&dirBuilder)
+	resolver.Register(&disBuilder)
 }
 
-type discovResolver struct {
+type nopResolver struct {
 	cc resolver.ClientConn
 }
 
-func (r *discovResolver) Close() {
+func (r *nopResolver) Close() {
 }
 
-func (r *discovResolver) ResolveNow(options resolver.ResolveNowOptions) {
-}
-
-func RegisterResolver() {
-	resolver.Register(&builder)
+func (r *nopResolver) ResolveNow(options resolver.ResolveNowOptions) {
 }

+ 17 - 0
rpcx/internal/target.go

@@ -0,0 +1,17 @@
+package internal
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/tal-tech/go-zero/rpcx/internal/resolver"
+)
+
+func BuildDirectTarget(endpoints []string) string {
+	return fmt.Sprintf("%s:///%s", resolver.DirectScheme, strings.Join(endpoints, resolver.EndpointSep))
+}
+
+func BuildDiscovTarget(endpoints []string, key string) string {
+	return fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme,
+		strings.Join(endpoints, resolver.EndpointSep), key)
+}

+ 5 - 5
rpcx/proxy.go

@@ -38,11 +38,11 @@ func (p *RpcProxy) TakeConn(ctx context.Context) (*grpc.ClientConn, error) {
 			return client, nil
 		}
 
-		client, err := NewClient(RpcClientConf{
-			Server: p.backend,
-			App:    cred.App,
-			Token:  cred.Token,
-		}, p.options...)
+		opts := append(p.options, WithDialOption(grpc.WithPerRPCCredentials(&auth.Credential{
+			App:   cred.App,
+			Token: cred.Token,
+		})))
+		client, err := NewClientWithTarget(p.backend, opts...)
 		if err != nil {
 			return nil, err
 		}