Преглед изворни кода

feat: add middlewares config for zrpc (#2766)

* feat: add middlewares config for zrpc

* chore: add tests

* chore: improve codecov

* chore: improve codecov
Kevin Wan пре 2 година
родитељ
комит
26c541b9cb

+ 9 - 2
zrpc/client.go

@@ -70,7 +70,7 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
 		return nil, err
 	}
 
-	client, err := internal.NewClient(target, opts...)
+	client, err := internal.NewClient(target, c.Middlewares, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -82,7 +82,14 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
 
 // NewClientWithTarget returns a Client with connecting to given target.
 func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) {
-	return internal.NewClient(target, opts...)
+	middlewares := ClientMiddlewaresConf{
+		Trace:      true,
+		Duration:   true,
+		Prometheus: true,
+		Breaker:    true,
+		Timeout:    true,
+	}
+	return internal.NewClient(target, middlewares, opts...)
 }
 
 // Conn returns the underlying grpc.ClientConn.

+ 21 - 0
zrpc/client_test.go

@@ -76,6 +76,13 @@ func TestDepositServer_Deposit(t *testing.T) {
 			App:       "foo",
 			Token:     "bar",
 			Timeout:   1000,
+			Middlewares: ClientMiddlewaresConf{
+				Trace:      true,
+				Duration:   true,
+				Prometheus: true,
+				Breaker:    true,
+				Timeout:    true,
+			},
 		},
 		WithDialOption(grpc.WithContextDialer(dialer())),
 		WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{},
@@ -90,6 +97,13 @@ func TestDepositServer_Deposit(t *testing.T) {
 			Token:     "bar",
 			Timeout:   1000,
 			NonBlock:  true,
+			Middlewares: ClientMiddlewaresConf{
+				Trace:      true,
+				Duration:   true,
+				Prometheus: true,
+				Breaker:    true,
+				Timeout:    true,
+			},
 		},
 		WithDialOption(grpc.WithContextDialer(dialer())),
 		WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{},
@@ -103,6 +117,13 @@ func TestDepositServer_Deposit(t *testing.T) {
 			App:     "foo",
 			Token:   "bar",
 			Timeout: 1000,
+			Middlewares: ClientMiddlewaresConf{
+				Trace:      true,
+				Duration:   true,
+				Prometheus: true,
+				Breaker:    true,
+				Timeout:    true,
+			},
 		},
 		WithDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())),
 		WithDialOption(grpc.WithContextDialer(dialer())),

+ 16 - 8
zrpc/config.go

@@ -4,10 +4,16 @@ import (
 	"github.com/zeromicro/go-zero/core/discov"
 	"github.com/zeromicro/go-zero/core/service"
 	"github.com/zeromicro/go-zero/core/stores/redis"
+	"github.com/zeromicro/go-zero/zrpc/internal"
 	"github.com/zeromicro/go-zero/zrpc/resolver"
 )
 
 type (
+	// ClientMiddlewaresConf defines whether to use client middlewares.
+	ClientMiddlewaresConf = internal.ClientMiddlewaresConf
+	// ServerMiddlewaresConf defines whether to use server middlewares.
+	ServerMiddlewaresConf = internal.ServerMiddlewaresConf
+
 	// A RpcServerConf is a rpc server config.
 	RpcServerConf struct {
 		service.ServiceConf
@@ -20,18 +26,20 @@ type (
 		Timeout      int64 `json:",default=2000"`
 		CpuThreshold int64 `json:",default=900,range=[0:1000]"`
 		// grpc health check switch
-		Health bool `json:",default=true"`
+		Health      bool `json:",default=true"`
+		Middlewares ServerMiddlewaresConf
 	}
 
 	// A RpcClientConf is a rpc client config.
 	RpcClientConf struct {
-		Etcd      discov.EtcdConf `json:",optional,inherit"`
-		Endpoints []string        `json:",optional"`
-		Target    string          `json:",optional"`
-		App       string          `json:",optional"`
-		Token     string          `json:",optional"`
-		NonBlock  bool            `json:",optional"`
-		Timeout   int64           `json:",default=2000"`
+		Etcd        discov.EtcdConf `json:",optional,inherit"`
+		Endpoints   []string        `json:",optional"`
+		Target      string          `json:",optional"`
+		App         string          `json:",optional"`
+		Token       string          `json:",optional"`
+		NonBlock    bool            `json:",optional"`
+		Timeout     int64           `json:",default=2000"`
+		Middlewares ClientMiddlewaresConf
 	}
 )
 

+ 41 - 13
zrpc/internal/client.go

@@ -42,13 +42,17 @@ type (
 	ClientOption func(options *ClientOptions)
 
 	client struct {
-		conn *grpc.ClientConn
+		conn        *grpc.ClientConn
+		middlewares ClientMiddlewaresConf
 	}
 )
 
 // NewClient returns a Client.
-func NewClient(target string, opts ...ClientOption) (Client, error) {
-	var cli client
+func NewClient(target string, middlewares ClientMiddlewaresConf,
+	opts ...ClientOption) (Client, error) {
+	cli := client{
+		middlewares: middlewares,
+	}
 
 	svcCfg := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name)
 	balancerOpt := WithDialOption(grpc.WithDefaultServiceConfig(svcCfg))
@@ -80,21 +84,45 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption {
 	}
 
 	options = append(options,
-		WithUnaryClientInterceptors(
-			clientinterceptors.UnaryTracingInterceptor,
-			clientinterceptors.DurationInterceptor,
-			clientinterceptors.PrometheusInterceptor,
-			clientinterceptors.BreakerInterceptor,
-			clientinterceptors.TimeoutInterceptor(cliOpts.Timeout),
-		),
-		WithStreamClientInterceptors(
-			clientinterceptors.StreamTracingInterceptor,
-		),
+		WithUnaryClientInterceptors(c.buildUnaryInterceptors(cliOpts.Timeout)...),
+		WithStreamClientInterceptors(c.buildStreamInterceptors()...),
 	)
 
 	return append(options, cliOpts.DialOptions...)
 }
 
+func (c *client) buildStreamInterceptors() []grpc.StreamClientInterceptor {
+	var interceptors []grpc.StreamClientInterceptor
+
+	if c.middlewares.Trace {
+		interceptors = append(interceptors, clientinterceptors.StreamTracingInterceptor)
+	}
+
+	return interceptors
+}
+
+func (c *client) buildUnaryInterceptors(timeout time.Duration) []grpc.UnaryClientInterceptor {
+	var interceptors []grpc.UnaryClientInterceptor
+
+	if c.middlewares.Trace {
+		interceptors = append(interceptors, clientinterceptors.UnaryTracingInterceptor)
+	}
+	if c.middlewares.Duration {
+		interceptors = append(interceptors, clientinterceptors.DurationInterceptor)
+	}
+	if c.middlewares.Prometheus {
+		interceptors = append(interceptors, clientinterceptors.PrometheusInterceptor)
+	}
+	if c.middlewares.Breaker {
+		interceptors = append(interceptors, clientinterceptors.BreakerInterceptor)
+	}
+	if c.middlewares.Timeout {
+		interceptors = append(interceptors, clientinterceptors.TimeoutInterceptor(timeout))
+	}
+
+	return interceptors
+}
+
 func (c *client) dial(server string, opts ...ClientOption) error {
 	options := c.buildDialOptions(opts...)
 	timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout)

+ 57 - 1
zrpc/internal/client_test.go

@@ -2,6 +2,8 @@ package internal
 
 import (
 	"context"
+	"net"
+	"strings"
 	"testing"
 	"time"
 
@@ -60,8 +62,62 @@ func TestWithUnaryClientInterceptor(t *testing.T) {
 }
 
 func TestBuildDialOptions(t *testing.T) {
-	var c client
+	c := client{
+		middlewares: ClientMiddlewaresConf{
+			Trace:      true,
+			Duration:   true,
+			Prometheus: true,
+			Breaker:    true,
+			Timeout:    true,
+		},
+	}
 	agent := grpc.WithUserAgent("chrome")
 	opts := c.buildDialOptions(WithDialOption(agent))
 	assert.Contains(t, opts, agent)
 }
+
+func TestClientDial(t *testing.T) {
+	server := grpc.NewServer()
+
+	go func() {
+		lis, err := net.Listen("tcp", "localhost:54321")
+		assert.NoError(t, err)
+		defer lis.Close()
+		server.Serve(lis)
+	}()
+
+	time.Sleep(time.Millisecond)
+
+	c, err := NewClient("localhost:54321", ClientMiddlewaresConf{
+		Trace:      true,
+		Duration:   true,
+		Prometheus: true,
+		Breaker:    true,
+		Timeout:    true,
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, c.Conn())
+	server.Stop()
+}
+
+func TestClientDialFail(t *testing.T) {
+	_, err := NewClient("localhost:54321", ClientMiddlewaresConf{
+		Trace:      true,
+		Duration:   true,
+		Prometheus: true,
+		Breaker:    true,
+		Timeout:    true,
+	})
+	assert.Error(t, err)
+	assert.True(t, strings.Contains(err.Error(), "localhost:54321"))
+
+	_, err = NewClient("localhost:54321/fail", ClientMiddlewaresConf{
+		Trace:      true,
+		Duration:   true,
+		Prometheus: true,
+		Breaker:    true,
+		Timeout:    true,
+	})
+	assert.Error(t, err)
+	assert.True(t, strings.Contains(err.Error(), "localhost:54321/fail"))
+}

+ 21 - 0
zrpc/internal/config.go

@@ -0,0 +1,21 @@
+package internal
+
+type (
+	// ClientMiddlewaresConf defines whether to use client middlewares.
+	ClientMiddlewaresConf struct {
+		Trace      bool `json:",default=true"`
+		Duration   bool `json:",default=true"`
+		Prometheus bool `json:",default=true"`
+		Breaker    bool `json:",default=true"`
+		Timeout    bool `json:",default=true"`
+	}
+
+	// ServerMiddlewaresConf defines whether to use server middlewares.
+	ServerMiddlewaresConf struct {
+		Trace      bool `json:",default=true"`
+		Recover    bool `json:",default=true"`
+		Stat       bool `json:",default=true"`
+		Prometheus bool `json:",default=true"`
+		Breaker    bool `json:",default=true"`
+	}
+)

+ 3 - 2
zrpc/internal/rpcpubserver.go

@@ -14,7 +14,8 @@ const (
 )
 
 // NewRpcPubServer returns a Server.
-func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption) (Server, error) {
+func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, middlewares ServerMiddlewaresConf,
+	opts ...ServerOption) (Server, error) {
 	registerEtcd := func() error {
 		pubListenOn := figureOutListenOn(listenOn)
 		var pubOpts []discov.PubOption
@@ -30,7 +31,7 @@ func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption
 	}
 	server := keepAliveServer{
 		registerEtcd: registerEtcd,
-		Server:       NewRpcServer(listenOn, opts...),
+		Server:       NewRpcServer(listenOn, middlewares, opts...),
 	}
 
 	return server, nil

+ 43 - 13
zrpc/internal/rpcserver.go

@@ -26,12 +26,13 @@ type (
 	rpcServer struct {
 		*baseRpcServer
 		name          string
+		middlewares   ServerMiddlewaresConf
 		healthManager health.Probe
 	}
 )
 
 // NewRpcServer returns a Server.
-func NewRpcServer(addr string, opts ...ServerOption) Server {
+func NewRpcServer(addr string, middlewares ServerMiddlewaresConf, opts ...ServerOption) Server {
 	var options rpcServerOptions
 	for _, opt := range opts {
 		opt(&options)
@@ -42,6 +43,7 @@ func NewRpcServer(addr string, opts ...ServerOption) Server {
 
 	return &rpcServer{
 		baseRpcServer: newBaseRpcServer(addr, &options),
+		middlewares:   middlewares,
 		healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, addr)),
 	}
 }
@@ -57,19 +59,9 @@ func (s *rpcServer) Start(register RegisterFn) error {
 		return err
 	}
 
-	unaryInterceptors := []grpc.UnaryServerInterceptor{
-		serverinterceptors.UnaryTracingInterceptor,
-		serverinterceptors.UnaryCrashInterceptor,
-		serverinterceptors.UnaryStatInterceptor(s.metrics),
-		serverinterceptors.UnaryPrometheusInterceptor,
-		serverinterceptors.UnaryBreakerInterceptor,
-	}
+	unaryInterceptors := s.buildUnaryInterceptors()
 	unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
-	streamInterceptors := []grpc.StreamServerInterceptor{
-		serverinterceptors.StreamTracingInterceptor,
-		serverinterceptors.StreamCrashInterceptor,
-		serverinterceptors.StreamBreakerInterceptor,
-	}
+	streamInterceptors := s.buildStreamInterceptors()
 	streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
 	options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...),
 		WithStreamServerInterceptors(streamInterceptors...))
@@ -97,6 +89,44 @@ func (s *rpcServer) Start(register RegisterFn) error {
 	return server.Serve(lis)
 }
 
+func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
+	var interceptors []grpc.StreamServerInterceptor
+
+	if s.middlewares.Trace {
+		interceptors = append(interceptors, serverinterceptors.StreamTracingInterceptor)
+	}
+	if s.middlewares.Recover {
+		interceptors = append(interceptors, serverinterceptors.StreamRecoverInterceptor)
+	}
+	if s.middlewares.Breaker {
+		interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
+	}
+
+	return interceptors
+}
+
+func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
+	var interceptors []grpc.UnaryServerInterceptor
+
+	if s.middlewares.Trace {
+		interceptors = append(interceptors, serverinterceptors.UnaryTracingInterceptor)
+	}
+	if s.middlewares.Recover {
+		interceptors = append(interceptors, serverinterceptors.UnaryRecoverInterceptor)
+	}
+	if s.middlewares.Stat {
+		interceptors = append(interceptors, serverinterceptors.UnaryStatInterceptor(s.metrics))
+	}
+	if s.middlewares.Prometheus {
+		interceptors = append(interceptors, serverinterceptors.UnaryPrometheusInterceptor)
+	}
+	if s.middlewares.Breaker {
+		interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
+	}
+
+	return interceptors
+}
+
 // WithMetrics returns a func that sets metrics to a Server.
 func WithMetrics(metrics *stat.Metrics) ServerOption {
 	return func(options *rpcServerOptions) {

+ 14 - 2
zrpc/internal/rpcserver_test.go

@@ -12,7 +12,13 @@ import (
 
 func TestRpcServer(t *testing.T) {
 	metrics := stat.NewMetrics("foo")
-	server := NewRpcServer("localhost:54321", WithMetrics(metrics))
+	server := NewRpcServer("localhost:54321", ServerMiddlewaresConf{
+		Trace:      true,
+		Recover:    true,
+		Stat:       true,
+		Prometheus: true,
+		Breaker:    true,
+	}, WithMetrics(metrics))
 	server.SetName("mock")
 	var wg sync.WaitGroup
 	var grpcServer *grpc.Server
@@ -36,7 +42,13 @@ func TestRpcServer(t *testing.T) {
 }
 
 func TestRpcServer_WithBadAddress(t *testing.T) {
-	server := NewRpcServer("localhost:111111")
+	server := NewRpcServer("localhost:111111", ServerMiddlewaresConf{
+		Trace:      true,
+		Recover:    true,
+		Stat:       true,
+		Prometheus: true,
+		Breaker:    true,
+	})
 	server.SetName("mock")
 	err := server.Start(func(server *grpc.Server) {
 		mock.RegisterDepositServiceServer(server, new(mock.DepositServer))

+ 4 - 4
zrpc/internal/serverinterceptors/crashinterceptor.go → zrpc/internal/serverinterceptors/recoverinterceptor.go

@@ -10,8 +10,8 @@ import (
 	"google.golang.org/grpc/status"
 )
 
-// StreamCrashInterceptor catches panics in processing stream requests and recovers.
-func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
+// StreamRecoverInterceptor catches panics in processing stream requests and recovers.
+func StreamRecoverInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
 	handler grpc.StreamHandler) (err error) {
 	defer handleCrash(func(r interface{}) {
 		err = toPanicError(r)
@@ -20,8 +20,8 @@ func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.S
 	return handler(svr, stream)
 }
 
-// UnaryCrashInterceptor catches panics in processing unary requests and recovers.
-func UnaryCrashInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo,
+// UnaryRecoverInterceptor catches panics in processing unary requests and recovers.
+func UnaryRecoverInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo,
 	handler grpc.UnaryHandler) (resp interface{}, err error) {
 	defer handleCrash(func(r interface{}) {
 		err = toPanicError(r)

+ 2 - 2
zrpc/internal/serverinterceptors/crashinterceptor_test.go → zrpc/internal/serverinterceptors/recoverinterceptor_test.go

@@ -14,7 +14,7 @@ func init() {
 }
 
 func TestStreamCrashInterceptor(t *testing.T) {
-	err := StreamCrashInterceptor(nil, nil, nil, func(
+	err := StreamRecoverInterceptor(nil, nil, nil, func(
 		svr interface{}, stream grpc.ServerStream) error {
 		panic("mock panic")
 	})
@@ -22,7 +22,7 @@ func TestStreamCrashInterceptor(t *testing.T) {
 }
 
 func TestUnaryCrashInterceptor(t *testing.T) {
-	_, err := UnaryCrashInterceptor(context.Background(), nil, nil,
+	_, err := UnaryRecoverInterceptor(context.Background(), nil, nil,
 		func(ctx context.Context, req interface{}) (interface{}, error) {
 			panic("mock panic")
 		})

+ 2 - 2
zrpc/server.go

@@ -44,12 +44,12 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error
 	}
 
 	if c.HasEtcd() {
-		server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, serverOptions...)
+		server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, c.Middlewares, serverOptions...)
 		if err != nil {
 			return nil, err
 		}
 	} else {
-		server = internal.NewRpcServer(c.ListenOn, serverOptions...)
+		server = internal.NewRpcServer(c.ListenOn, c.Middlewares, serverOptions...)
 	}
 
 	server.SetName(c.Name)

+ 39 - 4
zrpc/server_test.go

@@ -28,6 +28,13 @@ func TestServer_setupInterceptors(t *testing.T) {
 		},
 		CpuThreshold: 10,
 		Timeout:      100,
+		Middlewares: ServerMiddlewaresConf{
+			Trace:      true,
+			Recover:    true,
+			Stat:       true,
+			Prometheus: true,
+			Breaker:    true,
+		},
 	}, new(stat.Metrics))
 	assert.Nil(t, err)
 	assert.Equal(t, 3, len(server.unaryInterceptors))
@@ -51,11 +58,18 @@ func TestServer(t *testing.T) {
 		StrictControl: false,
 		Timeout:       0,
 		CpuThreshold:  0,
+		Middlewares: ServerMiddlewaresConf{
+			Trace:      true,
+			Recover:    true,
+			Stat:       true,
+			Prometheus: true,
+			Breaker:    true,
+		},
 	}, func(server *grpc.Server) {
 	})
 	svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
-	svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
-	svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
+	svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor)
+	svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor)
 	go svr.Start()
 	svr.Stop()
 }
@@ -74,6 +88,13 @@ func TestServerError(t *testing.T) {
 		},
 		Auth:  true,
 		Redis: redis.RedisKeyConf{},
+		Middlewares: ServerMiddlewaresConf{
+			Trace:      true,
+			Recover:    true,
+			Stat:       true,
+			Prometheus: true,
+			Breaker:    true,
+		},
 	}, func(server *grpc.Server) {
 	})
 	assert.NotNil(t, err)
@@ -93,11 +114,18 @@ func TestServer_HasEtcd(t *testing.T) {
 			Key:   "any",
 		},
 		Redis: redis.RedisKeyConf{},
+		Middlewares: ServerMiddlewaresConf{
+			Trace:      true,
+			Recover:    true,
+			Stat:       true,
+			Prometheus: true,
+			Breaker:    true,
+		},
 	}, func(server *grpc.Server) {
 	})
 	svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
-	svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
-	svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
+	svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor)
+	svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor)
 	go svr.Start()
 	svr.Stop()
 }
@@ -111,6 +139,13 @@ func TestServer_StartFailed(t *testing.T) {
 			},
 		},
 		ListenOn: "localhost:aaa",
+		Middlewares: ServerMiddlewaresConf{
+			Trace:      true,
+			Recover:    true,
+			Stat:       true,
+			Prometheus: true,
+			Breaker:    true,
+		},
 	}, func(server *grpc.Server) {
 	})