浏览代码

chore: refactor zrpc timeout (#3671)

Kevin Wan 1 年之前
父节点
当前提交
922efbfc2d

+ 3 - 3
zrpc/client.go

@@ -111,7 +111,7 @@ func SetClientSlowThreshold(threshold time.Duration) {
 	clientinterceptors.SetSlowThreshold(threshold)
 }
 
-// WithTimeoutCallOption return a call option with given timeout.
-func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
-	return clientinterceptors.WithTimeoutCallOption(timeout)
+// WithCallTimeout return a call option with given timeout to make a method call.
+func WithCallTimeout(timeout time.Duration) grpc.CallOption {
+	return clientinterceptors.WithCallTimeout(timeout)
 }

+ 14 - 14
zrpc/client_test.go

@@ -41,12 +41,12 @@ func dialer() func(context.Context, string) (net.Conn, error) {
 
 func TestDepositServer_Deposit(t *testing.T) {
 	tests := []struct {
-		name              string
-		amount            float32
-		timeoutCallOption time.Duration
-		res               *mock.DepositResponse
-		errCode           codes.Code
-		errMsg            string
+		name    string
+		amount  float32
+		timeout time.Duration
+		res     *mock.DepositResponse
+		errCode codes.Code
+		errMsg  string
 	}{
 		{
 			name:    "invalid request with negative amount",
@@ -66,12 +66,12 @@ func TestDepositServer_Deposit(t *testing.T) {
 			errMsg:  "context deadline exceeded",
 		},
 		{
-			name:              "valid request with timeout call option",
-			amount:            2000.00,
-			timeoutCallOption: time.Second * 3,
-			res:               &mock.DepositResponse{Ok: true},
-			errCode:           codes.OK,
-			errMsg:            "",
+			name:    "valid request with timeout call option",
+			amount:  2000.00,
+			timeout: time.Second * 3,
+			res:     &mock.DepositResponse{Ok: true},
+			errCode: codes.OK,
+			errMsg:  "",
 		},
 	}
 
@@ -171,8 +171,8 @@ func TestDepositServer_Deposit(t *testing.T) {
 					err      error
 				)
 
-				if tt.timeoutCallOption > 0 {
-					response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption))
+				if tt.timeout > 0 {
+					response, err = cli.Deposit(ctx, request, WithCallTimeout(tt.timeout))
 				} else {
 					response, err = cli.Deposit(ctx, request)
 				}

+ 3 - 3
zrpc/config.go

@@ -17,8 +17,8 @@ type (
 	ServerMiddlewaresConf = internal.ServerMiddlewaresConf
 	// StatConf defines the stat config.
 	StatConf = internal.StatConf
-	// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
-	ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf
+	// MethodTimeoutConf defines specified timeout for gRPC method.
+	MethodTimeoutConf = internal.MethodTimeoutConf
 
 	// A RpcClientConf is a rpc client config.
 	RpcClientConf struct {
@@ -48,7 +48,7 @@ type (
 		Health      bool `json:",default=true"`
 		Middlewares ServerMiddlewaresConf
 		// setting specified timeout for gRPC method
-		SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"`
+		MethodTimeouts []MethodTimeoutConf `json:",optional"`
 	}
 )
 

+ 17 - 16
zrpc/internal/clientinterceptors/timeoutinterceptor.go

@@ -7,11 +7,17 @@ import (
 	"google.golang.org/grpc"
 )
 
+// TimeoutCallOption is a call option that controls timeout.
+type TimeoutCallOption struct {
+	grpc.EmptyCallOption
+	timeout time.Duration
+}
+
 // TimeoutInterceptor is an interceptor that controls timeout.
 func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
 	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
 		invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
-		t := getTimeoutByCallOptions(opts, timeout)
+		t := getTimeoutFromCallOptions(opts, timeout)
 		if t <= 0 {
 			return invoker(ctx, method, req, reply, cc, opts...)
 		}
@@ -23,24 +29,19 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
 	}
 }
 
-func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
-	for _, callOption := range callOptions {
-		if o, ok := callOption.(TimeoutCallOption); ok {
+// WithCallTimeout returns a call option that controls method call timeout.
+func WithCallTimeout(timeout time.Duration) grpc.CallOption {
+	return TimeoutCallOption{
+		timeout: timeout,
+	}
+}
+
+func getTimeoutFromCallOptions(opts []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
+	for _, opt := range opts {
+		if o, ok := opt.(TimeoutCallOption); ok {
 			return o.timeout
 		}
 	}
 
 	return defaultTimeout
 }
-
-type TimeoutCallOption struct {
-	grpc.EmptyCallOption
-
-	timeout time.Duration
-}
-
-func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
-	return TimeoutCallOption{
-		timeout: timeout,
-	}
-}

+ 1 - 1
zrpc/internal/clientinterceptors/timeoutinterceptor_test.go

@@ -114,7 +114,7 @@ func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
 			cc := new(grpc.ClientConn)
 			var co []grpc.CallOption
 			if tt.args.callOptionTimeout > 0 {
-				co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout))
+				co = append(co, WithCallTimeout(tt.args.callOptionTimeout))
 			}
 
 			err := interceptor(context.Background(), "/foo", nil, nil, cc,

+ 2 - 1
zrpc/internal/config.go

@@ -25,5 +25,6 @@ type (
 		Breaker    bool     `json:",default=true"`
 	}
 
-	ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf
+	// MethodTimeoutConf defines specified timeout for gRPC methods.
+	MethodTimeoutConf = serverinterceptors.MethodTimeoutConf
 )

+ 15 - 19
zrpc/internal/serverinterceptors/timeoutinterceptor.go

@@ -15,21 +15,22 @@ import (
 )
 
 type (
-	// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
-	ServerSpecifiedTimeoutConf struct {
+	// MethodTimeoutConf defines specified timeout for gRPC method.
+	MethodTimeoutConf struct {
 		FullMethod string
 		Timeout    time.Duration
 	}
 
-	specifiedTimeoutCache map[string]time.Duration
+	methodTimeouts map[string]time.Duration
 )
 
 // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
-func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
-	cache := cacheSpecifiedTimeout(specifiedTimeouts)
+func UnaryTimeoutInterceptor(timeout time.Duration,
+	methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor {
+	timeouts := buildMethodTimeouts(methodTimeouts)
 	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
 		handler grpc.UnaryHandler) (any, error) {
-		t := getTimeoutByUnaryServerInfo(info, timeout, cache)
+		t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout)
 		ctx, cancel := context.WithTimeout(ctx, t)
 		defer cancel()
 
@@ -72,27 +73,22 @@ func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerS
 	}
 }
 
-func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
-	cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
-	for _, st := range specifiedTimeouts {
+func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts {
+	mt := make(methodTimeouts, len(timeouts))
+	for _, st := range timeouts {
 		if st.FullMethod != "" {
-			cache[st.FullMethod] = st.Timeout
+			mt[st.FullMethod] = st.Timeout
 		}
 	}
 
-	return cache
+	return mt
 }
 
-func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
-	if ts, ok := info.Server.(TimeoutStrategy); ok {
-		return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
-	} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
+func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts,
+	defaultTimeout time.Duration) time.Duration {
+	if v, ok := timeouts[method]; ok {
 		return v
 	}
 
 	return defaultTimeout
 }
-
-type TimeoutStrategy interface {
-	GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
-}

+ 2 - 20
zrpc/internal/serverinterceptors/timeoutinterceptor_test.go

@@ -103,13 +103,6 @@ type tempServer struct {
 func (s *tempServer) run(duration time.Duration) {
 	time.Sleep(duration)
 }
-func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
-	if fullMethod == "/" {
-		return defaultTimeout
-	}
-
-	return s.timeout
-}
 
 func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
 	type args struct {
@@ -136,17 +129,6 @@ func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
 			},
 			wantErr: nil,
 		},
-		{
-			name: "do not timeout with timeout strategy",
-			args: args{
-				interceptorTimeout: time.Second,
-				contextTimeout:     time.Second * 5,
-				serverTimeout:      time.Second * 3,
-				runTime:            time.Second * 2,
-				fullMethod:         "/2s",
-			},
-			wantErr: nil,
-		},
 		{
 			name: "timeout with interceptor timeout",
 			args: args{
@@ -235,9 +217,9 @@ func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			t.Parallel()
 
-			var specifiedTimeouts []ServerSpecifiedTimeoutConf
+			var specifiedTimeouts []MethodTimeoutConf
 			if tt.args.methodTimeout > 0 {
-				specifiedTimeouts = []ServerSpecifiedTimeoutConf{
+				specifiedTimeouts = []MethodTimeoutConf{
 					{
 						FullMethod: tt.args.method,
 						Timeout:    tt.args.methodTimeout,

+ 2 - 6
zrpc/server.go

@@ -131,12 +131,8 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri
 	}
 
 	if c.Timeout > 0 {
-		svr.AddUnaryInterceptors(
-			serverinterceptors.UnaryTimeoutInterceptor(
-				time.Duration(c.Timeout)*time.Millisecond,
-				c.SpecifiedTimeouts...,
-			),
-		)
+		svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor(
+			time.Duration(c.Timeout)*time.Millisecond, c.MethodTimeouts...))
 	}
 
 	if c.Auth {

+ 4 - 4
zrpc/server_test.go

@@ -40,7 +40,7 @@ func TestServer_setupInterceptors(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
-		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
+		MethodTimeouts: []MethodTimeoutConf{
 			{
 				FullMethod: "/foo",
 				Timeout:    5 * time.Second,
@@ -81,7 +81,7 @@ func TestServer(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
-		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
+		MethodTimeouts: []MethodTimeoutConf{
 			{
 				FullMethod: "/foo",
 				Timeout:    time.Second,
@@ -117,7 +117,7 @@ func TestServerError(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
-		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
+		MethodTimeouts: []MethodTimeoutConf{},
 	}, func(server *grpc.Server) {
 	})
 	assert.NotNil(t, err)
@@ -144,7 +144,7 @@ func TestServer_HasEtcd(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
-		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
+		MethodTimeouts: []MethodTimeoutConf{},
 	}, func(server *grpc.Server) {
 	})
 	svr.AddOptions(grpc.ConnectionTimeout(time.Hour))