|
@@ -12,6 +12,11 @@ import (
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+var (
|
|
|
|
+ deadlineExceededErr = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
|
|
|
|
+ canceledErr = status.Error(codes.Canceled, context.Canceled.Error())
|
|
|
|
+)
|
|
|
|
+
|
|
func TestUnaryTimeoutInterceptor(t *testing.T) {
|
|
func TestUnaryTimeoutInterceptor(t *testing.T) {
|
|
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
|
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
|
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
|
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
|
@@ -68,7 +73,7 @@ func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
|
|
return nil, nil
|
|
return nil, nil
|
|
})
|
|
})
|
|
wg.Wait()
|
|
wg.Wait()
|
|
- assert.EqualValues(t, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), err)
|
|
|
|
|
|
+ assert.EqualValues(t, deadlineExceededErr, err)
|
|
}
|
|
}
|
|
|
|
|
|
func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
|
|
func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
|
|
@@ -88,5 +93,171 @@ func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
|
|
})
|
|
})
|
|
|
|
|
|
wg.Wait()
|
|
wg.Wait()
|
|
- assert.EqualValues(t, status.Error(codes.Canceled, context.Canceled.Error()), err)
|
|
|
|
|
|
+ assert.EqualValues(t, canceledErr, err)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+type tempServer struct {
|
|
|
|
+ timeout time.Duration
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (s *tempServer) run(duration time.Duration) {
|
|
|
|
+ time.Sleep(duration)
|
|
|
|
+}
|
|
|
|
+func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
|
|
|
|
+ if fullMethod == "/" {
|
|
|
|
+ return defaultTimeout
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return s.timeout
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
|
|
|
|
+ type args struct {
|
|
|
|
+ interceptorTimeout time.Duration
|
|
|
|
+ contextTimeout time.Duration
|
|
|
|
+ serverTimeout time.Duration
|
|
|
|
+ runTime time.Duration
|
|
|
|
+
|
|
|
|
+ fullMethod string
|
|
|
|
+ }
|
|
|
|
+ var tests = []struct {
|
|
|
|
+ name string
|
|
|
|
+ args args
|
|
|
|
+ wantErr error
|
|
|
|
+ }{
|
|
|
|
+ {
|
|
|
|
+ name: "do not timeout with interceptor timeout",
|
|
|
|
+ args: args{
|
|
|
|
+ interceptorTimeout: time.Second,
|
|
|
|
+ contextTimeout: time.Second * 5,
|
|
|
|
+ serverTimeout: time.Second * 3,
|
|
|
|
+ runTime: time.Millisecond * 50,
|
|
|
|
+ fullMethod: "/",
|
|
|
|
+ },
|
|
|
|
+ wantErr: nil,
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "do not timeout with timeout strategy",
|
|
|
|
+ args: args{
|
|
|
|
+ interceptorTimeout: time.Second,
|
|
|
|
+ contextTimeout: time.Second * 5,
|
|
|
|
+ serverTimeout: time.Second * 3,
|
|
|
|
+ runTime: time.Second * 2,
|
|
|
|
+ fullMethod: "/2s",
|
|
|
|
+ },
|
|
|
|
+ wantErr: nil,
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "timeout with interceptor timeout",
|
|
|
|
+ args: args{
|
|
|
|
+ interceptorTimeout: time.Second,
|
|
|
|
+ contextTimeout: time.Second * 5,
|
|
|
|
+ serverTimeout: time.Second * 3,
|
|
|
|
+ runTime: time.Second * 2,
|
|
|
|
+ fullMethod: "/",
|
|
|
|
+ },
|
|
|
|
+ wantErr: deadlineExceededErr,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+ for _, tt := range tests {
|
|
|
|
+ tt := tt
|
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
|
+ t.Parallel()
|
|
|
|
+
|
|
|
|
+ interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout)
|
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
|
|
|
|
+ defer cancel()
|
|
|
|
+
|
|
|
|
+ svr := &tempServer{timeout: tt.args.serverTimeout}
|
|
|
|
+
|
|
|
|
+ _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
|
|
|
+ Server: svr,
|
|
|
|
+ FullMethod: tt.args.fullMethod,
|
|
|
|
+ }, func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
|
|
+ svr.run(tt.args.runTime)
|
|
|
|
+ return nil, nil
|
|
|
|
+ })
|
|
|
|
+ t.Logf("error: %+v", err)
|
|
|
|
+
|
|
|
|
+ assert.EqualValues(t, tt.wantErr, err)
|
|
|
|
+ })
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
|
|
|
|
+ type args struct {
|
|
|
|
+ interceptorTimeout time.Duration
|
|
|
|
+ contextTimeout time.Duration
|
|
|
|
+ method string
|
|
|
|
+ methodTimeout time.Duration
|
|
|
|
+ runTime time.Duration
|
|
|
|
+ }
|
|
|
|
+ var tests = []struct {
|
|
|
|
+ name string
|
|
|
|
+ args args
|
|
|
|
+ wantErr error
|
|
|
|
+ }{
|
|
|
|
+ {
|
|
|
|
+ name: "do not timeout without set timeout for full method",
|
|
|
|
+ args: args{
|
|
|
|
+ interceptorTimeout: time.Second,
|
|
|
|
+ contextTimeout: time.Second * 5,
|
|
|
|
+ method: "/run",
|
|
|
|
+ runTime: time.Millisecond * 50,
|
|
|
|
+ },
|
|
|
|
+ wantErr: nil,
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "do not timeout with set timeout for full method",
|
|
|
|
+ args: args{
|
|
|
|
+ interceptorTimeout: time.Second,
|
|
|
|
+ contextTimeout: time.Second * 5,
|
|
|
|
+ method: "/run/do_not_timeout",
|
|
|
|
+ methodTimeout: time.Second * 3,
|
|
|
|
+ runTime: time.Second * 2,
|
|
|
|
+ },
|
|
|
|
+ wantErr: nil,
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "timeout with set timeout for full method",
|
|
|
|
+ args: args{
|
|
|
|
+ interceptorTimeout: time.Second,
|
|
|
|
+ contextTimeout: time.Second * 5,
|
|
|
|
+ method: "/run/timeout",
|
|
|
|
+ methodTimeout: time.Millisecond * 100,
|
|
|
|
+ runTime: time.Millisecond * 500,
|
|
|
|
+ },
|
|
|
|
+ wantErr: deadlineExceededErr,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+ for _, tt := range tests {
|
|
|
|
+ tt := tt
|
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
|
+ t.Parallel()
|
|
|
|
+
|
|
|
|
+ var specifiedTimeouts []ServerSpecifiedTimeoutConf
|
|
|
|
+ if tt.args.methodTimeout > 0 {
|
|
|
|
+ specifiedTimeouts = []ServerSpecifiedTimeoutConf{
|
|
|
|
+ {
|
|
|
|
+ FullMethod: tt.args.method,
|
|
|
|
+ Timeout: tt.args.methodTimeout,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout, specifiedTimeouts...)
|
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
|
|
|
|
+ defer cancel()
|
|
|
|
+
|
|
|
|
+ _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
|
|
|
+ FullMethod: tt.args.method,
|
|
|
|
+ }, func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
|
|
+ time.Sleep(tt.args.runTime)
|
|
|
|
+ return nil, nil
|
|
|
|
+ })
|
|
|
|
+ t.Logf("error: %+v", err)
|
|
|
|
+
|
|
|
|
+ assert.EqualValues(t, tt.wantErr, err)
|
|
|
|
+ })
|
|
|
|
+ }
|
|
}
|
|
}
|