浏览代码

feat: support the specified timeout of rpc methods (#2742)

Co-authored-by: hanzijian <hanzijian@52tt.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
vankillua 1 年之前
父节点
当前提交
842c4d81cc

+ 5 - 0
zrpc/client.go

@@ -110,3 +110,8 @@ func DontLogClientContentForMethod(method string) {
 func SetClientSlowThreshold(threshold time.Duration) {
 	clientinterceptors.SetSlowThreshold(threshold)
 }
+
+// WithTimeoutCallOption return a call option with given timeout.
+func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
+	return clientinterceptors.WithTimeoutCallOption(timeout)
+}

+ 39 - 21
zrpc/client_test.go

@@ -41,32 +41,37 @@ func dialer() func(context.Context, string) (net.Conn, error) {
 
 func TestDepositServer_Deposit(t *testing.T) {
 	tests := []struct {
-		name    string
-		amount  float32
-		res     *mock.DepositResponse
-		errCode codes.Code
-		errMsg  string
+		name              string
+		amount            float32
+		timeoutCallOption time.Duration
+		res               *mock.DepositResponse
+		errCode           codes.Code
+		errMsg            string
 	}{
 		{
-			"invalid request with negative amount",
-			-1.11,
-			nil,
-			codes.InvalidArgument,
-			fmt.Sprintf("cannot deposit %v", -1.11),
+			name:    "invalid request with negative amount",
+			amount:  -1.11,
+			errCode: codes.InvalidArgument,
+			errMsg:  fmt.Sprintf("cannot deposit %v", -1.11),
 		},
 		{
-			"valid request with non negative amount",
-			0.00,
-			&mock.DepositResponse{Ok: true},
-			codes.OK,
-			"",
+			name:    "valid request with non negative amount",
+			res:     &mock.DepositResponse{Ok: true},
+			errCode: codes.OK,
 		},
 		{
-			"valid request with long handling time",
-			2000.00,
-			nil,
-			codes.DeadlineExceeded,
-			"context deadline exceeded",
+			name:    "valid request with long handling time",
+			amount:  2000.00,
+			errCode: codes.DeadlineExceeded,
+			errMsg:  "context deadline exceeded",
+		},
+		{
+			name:              "valid request with timeout call option",
+			amount:            2000.00,
+			timeoutCallOption: time.Second * 3,
+			res:               &mock.DepositResponse{Ok: true},
+			errCode:           codes.OK,
+			errMsg:            "",
 		},
 	}
 
@@ -156,9 +161,22 @@ func TestDepositServer_Deposit(t *testing.T) {
 			client := client
 			t.Run(tt.name, func(t *testing.T) {
 				t.Parallel()
+
 				cli := mock.NewDepositServiceClient(client.Conn())
 				request := &mock.DepositRequest{Amount: tt.amount}
-				response, err := cli.Deposit(context.Background(), request)
+
+				var (
+					ctx      = context.Background()
+					response *mock.DepositResponse
+					err      error
+				)
+
+				if tt.timeoutCallOption > 0 {
+					response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption))
+				} else {
+					response, err = cli.Deposit(ctx, request)
+				}
+
 				if response != nil {
 					assert.True(t, len(response.String()) > 0)
 					if response.GetOk() != tt.res.GetOk() {

+ 4 - 0
zrpc/config.go

@@ -17,6 +17,8 @@ type (
 	ServerMiddlewaresConf = internal.ServerMiddlewaresConf
 	// StatConf defines the stat config.
 	StatConf = internal.StatConf
+	// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
+	ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf
 
 	// A RpcClientConf is a rpc client config.
 	RpcClientConf struct {
@@ -45,6 +47,8 @@ type (
 		// grpc health check switch
 		Health      bool `json:",default=true"`
 		Middlewares ServerMiddlewaresConf
+		// setting specified timeout for gRPC method
+		SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"`
 	}
 )
 

+ 25 - 2
zrpc/internal/clientinterceptors/timeoutinterceptor.go

@@ -11,13 +11,36 @@ import (
 func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
 	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
 		invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
-		if timeout <= 0 {
+		t := getTimeoutByCallOptions(opts, timeout)
+		if t <= 0 {
 			return invoker(ctx, method, req, reply, cc, opts...)
 		}
 
-		ctx, cancel := context.WithTimeout(ctx, timeout)
+		ctx, cancel := context.WithTimeout(ctx, t)
 		defer cancel()
 
 		return invoker(ctx, method, req, reply, cc, opts...)
 	}
 }
+
+func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
+	for _, callOption := range callOptions {
+		if o, ok := callOption.(TimeoutCallOption); ok {
+			return o.timeout
+		}
+	}
+
+	return defaultTimeout
+}
+
+type TimeoutCallOption struct {
+	grpc.EmptyCallOption
+
+	timeout time.Duration
+}
+
+func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
+	return TimeoutCallOption{
+		timeout: timeout,
+	}
+}

+ 71 - 0
zrpc/internal/clientinterceptors/timeoutinterceptor_test.go

@@ -66,3 +66,74 @@ func TestTimeoutInterceptor_panic(t *testing.T) {
 		})
 	}
 }
+
+func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
+	type args struct {
+		interceptorTimeout time.Duration
+		callOptionTimeout  time.Duration
+		runTime            time.Duration
+	}
+	var tests = []struct {
+		name    string
+		args    args
+		wantErr error
+	}{
+		{
+			name: "do not timeout without call option timeout",
+			args: args{
+				interceptorTimeout: time.Second,
+				runTime:            time.Millisecond * 50,
+			},
+			wantErr: nil,
+		},
+		{
+			name: "timeout without call option timeout",
+			args: args{
+				interceptorTimeout: time.Second,
+				runTime:            time.Second * 2,
+			},
+			wantErr: context.DeadlineExceeded,
+		},
+		{
+			name: "do not timeout with call option timeout",
+			args: args{
+				interceptorTimeout: time.Second,
+				callOptionTimeout:  time.Second * 3,
+				runTime:            time.Second * 2,
+			},
+			wantErr: nil,
+		},
+	}
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			interceptor := TimeoutInterceptor(tt.args.interceptorTimeout)
+
+			cc := new(grpc.ClientConn)
+			var co []grpc.CallOption
+			if tt.args.callOptionTimeout > 0 {
+				co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout))
+			}
+
+			err := interceptor(context.Background(), "/foo", nil, nil, cc,
+				func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
+					opts ...grpc.CallOption) error {
+					timer := time.NewTimer(tt.args.runTime)
+					defer timer.Stop()
+
+					select {
+					case <-timer.C:
+						return nil
+					case <-ctx.Done():
+						return ctx.Err()
+					}
+				}, co...,
+			)
+			t.Logf("error: %+v", err)
+
+			assert.EqualValues(t, tt.wantErr, err)
+		})
+	}
+}

+ 2 - 0
zrpc/internal/config.go

@@ -24,4 +24,6 @@ type (
 		Prometheus bool     `json:",default=true"`
 		Breaker    bool     `json:",default=true"`
 	}
+
+	ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf
 )

+ 39 - 2
zrpc/internal/serverinterceptors/timeoutinterceptor.go

@@ -14,11 +14,23 @@ import (
 	"google.golang.org/grpc/status"
 )
 
+type (
+	// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
+	ServerSpecifiedTimeoutConf struct {
+		FullMethod string
+		Timeout    time.Duration
+	}
+
+	specifiedTimeoutCache map[string]time.Duration
+)
+
 // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
-func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
+func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
+	cache := cacheSpecifiedTimeout(specifiedTimeouts)
 	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
 		handler grpc.UnaryHandler) (any, error) {
-		ctx, cancel := context.WithTimeout(ctx, timeout)
+		t := getTimeoutByUnaryServerInfo(info, timeout, cache)
+		ctx, cancel := context.WithTimeout(ctx, t)
 		defer cancel()
 
 		var resp any
@@ -59,3 +71,28 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor
 		}
 	}
 }
+
+func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
+	cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
+	for _, st := range specifiedTimeouts {
+		if st.FullMethod != "" {
+			cache[st.FullMethod] = st.Timeout
+		}
+	}
+
+	return cache
+}
+
+func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
+	if ts, ok := info.Server.(TimeoutStrategy); ok {
+		return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
+	} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
+		return v
+	}
+
+	return defaultTimeout
+}
+
+type TimeoutStrategy interface {
+	GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
+}

+ 173 - 2
zrpc/internal/serverinterceptors/timeoutinterceptor_test.go

@@ -12,6 +12,11 @@ import (
 	"google.golang.org/grpc/status"
 )
 
+var (
+	deadlineExceededErr = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
+	canceledErr         = status.Error(codes.Canceled, context.Canceled.Error())
+)
+
 func TestUnaryTimeoutInterceptor(t *testing.T) {
 	interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
 	_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
@@ -68,7 +73,7 @@ func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
 		return nil, nil
 	})
 	wg.Wait()
-	assert.EqualValues(t, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), err)
+	assert.EqualValues(t, deadlineExceededErr, err)
 }
 
 func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
@@ -88,5 +93,171 @@ func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
 	})
 
 	wg.Wait()
-	assert.EqualValues(t, status.Error(codes.Canceled, context.Canceled.Error()), err)
+	assert.EqualValues(t, canceledErr, err)
+}
+
+type tempServer struct {
+	timeout time.Duration
+}
+
+func (s *tempServer) run(duration time.Duration) {
+	time.Sleep(duration)
+}
+func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
+	if fullMethod == "/" {
+		return defaultTimeout
+	}
+
+	return s.timeout
+}
+
+func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
+	type args struct {
+		interceptorTimeout time.Duration
+		contextTimeout     time.Duration
+		serverTimeout      time.Duration
+		runTime            time.Duration
+
+		fullMethod string
+	}
+	var tests = []struct {
+		name    string
+		args    args
+		wantErr error
+	}{
+		{
+			name: "do not timeout with interceptor timeout",
+			args: args{
+				interceptorTimeout: time.Second,
+				contextTimeout:     time.Second * 5,
+				serverTimeout:      time.Second * 3,
+				runTime:            time.Millisecond * 50,
+				fullMethod:         "/",
+			},
+			wantErr: nil,
+		},
+		{
+			name: "do not timeout with timeout strategy",
+			args: args{
+				interceptorTimeout: time.Second,
+				contextTimeout:     time.Second * 5,
+				serverTimeout:      time.Second * 3,
+				runTime:            time.Second * 2,
+				fullMethod:         "/2s",
+			},
+			wantErr: nil,
+		},
+		{
+			name: "timeout with interceptor timeout",
+			args: args{
+				interceptorTimeout: time.Second,
+				contextTimeout:     time.Second * 5,
+				serverTimeout:      time.Second * 3,
+				runTime:            time.Second * 2,
+				fullMethod:         "/",
+			},
+			wantErr: deadlineExceededErr,
+		},
+	}
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout)
+			ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
+			defer cancel()
+
+			svr := &tempServer{timeout: tt.args.serverTimeout}
+
+			_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
+				Server:     svr,
+				FullMethod: tt.args.fullMethod,
+			}, func(ctx context.Context, req interface{}) (interface{}, error) {
+				svr.run(tt.args.runTime)
+				return nil, nil
+			})
+			t.Logf("error: %+v", err)
+
+			assert.EqualValues(t, tt.wantErr, err)
+		})
+	}
+}
+
+func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
+	type args struct {
+		interceptorTimeout time.Duration
+		contextTimeout     time.Duration
+		method             string
+		methodTimeout      time.Duration
+		runTime            time.Duration
+	}
+	var tests = []struct {
+		name    string
+		args    args
+		wantErr error
+	}{
+		{
+			name: "do not timeout without set timeout for full method",
+			args: args{
+				interceptorTimeout: time.Second,
+				contextTimeout:     time.Second * 5,
+				method:             "/run",
+				runTime:            time.Millisecond * 50,
+			},
+			wantErr: nil,
+		},
+		{
+			name: "do not timeout with set timeout for full method",
+			args: args{
+				interceptorTimeout: time.Second,
+				contextTimeout:     time.Second * 5,
+				method:             "/run/do_not_timeout",
+				methodTimeout:      time.Second * 3,
+				runTime:            time.Second * 2,
+			},
+			wantErr: nil,
+		},
+		{
+			name: "timeout with set timeout for full method",
+			args: args{
+				interceptorTimeout: time.Second,
+				contextTimeout:     time.Second * 5,
+				method:             "/run/timeout",
+				methodTimeout:      time.Millisecond * 100,
+				runTime:            time.Millisecond * 500,
+			},
+			wantErr: deadlineExceededErr,
+		},
+	}
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			var specifiedTimeouts []ServerSpecifiedTimeoutConf
+			if tt.args.methodTimeout > 0 {
+				specifiedTimeouts = []ServerSpecifiedTimeoutConf{
+					{
+						FullMethod: tt.args.method,
+						Timeout:    tt.args.methodTimeout,
+					},
+				}
+			}
+
+			interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout, specifiedTimeouts...)
+			ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
+			defer cancel()
+
+			_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
+				FullMethod: tt.args.method,
+			}, func(ctx context.Context, req interface{}) (interface{}, error) {
+				time.Sleep(tt.args.runTime)
+				return nil, nil
+			})
+			t.Logf("error: %+v", err)
+
+			assert.EqualValues(t, tt.wantErr, err)
+		})
+	}
 }

+ 6 - 2
zrpc/server.go

@@ -131,8 +131,12 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri
 	}
 
 	if c.Timeout > 0 {
-		svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor(
-			time.Duration(c.Timeout) * time.Millisecond))
+		svr.AddUnaryInterceptors(
+			serverinterceptors.UnaryTimeoutInterceptor(
+				time.Duration(c.Timeout)*time.Millisecond,
+				c.SpecifiedTimeouts...,
+			),
+		)
 	}
 
 	if c.Auth {

+ 14 - 0
zrpc/server_test.go

@@ -40,6 +40,12 @@ func TestServer_setupInterceptors(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
+		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
+			{
+				FullMethod: "/foo",
+				Timeout:    5 * time.Second,
+			},
+		},
 	}
 	err = setupInterceptors(server, conf, new(stat.Metrics))
 	assert.Nil(t, err)
@@ -75,6 +81,12 @@ func TestServer(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
+		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
+			{
+				FullMethod: "/foo",
+				Timeout:    time.Second,
+			},
+		},
 	}, func(server *grpc.Server) {
 	})
 	svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
@@ -105,6 +117,7 @@ func TestServerError(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
+		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
 	}, func(server *grpc.Server) {
 	})
 	assert.NotNil(t, err)
@@ -131,6 +144,7 @@ func TestServer_HasEtcd(t *testing.T) {
 			Prometheus: true,
 			Breaker:    true,
 		},
+		SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
 	}, func(server *grpc.Server) {
 	})
 	svr.AddOptions(grpc.ConnectionTimeout(time.Hour))