timeoutinterceptor_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "sync"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/codes"
  10. "google.golang.org/grpc/status"
  11. )
  12. var (
  13. deadlineExceededErr = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
  14. canceledErr = status.Error(codes.Canceled, context.Canceled.Error())
  15. )
  16. func TestUnaryTimeoutInterceptor(t *testing.T) {
  17. interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
  18. _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  19. FullMethod: "/",
  20. }, func(ctx context.Context, req any) (any, error) {
  21. return nil, nil
  22. })
  23. assert.Nil(t, err)
  24. }
  25. func TestUnaryTimeoutInterceptor_panic(t *testing.T) {
  26. interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
  27. assert.Panics(t, func() {
  28. _, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  29. FullMethod: "/",
  30. }, func(ctx context.Context, req any) (any, error) {
  31. panic("any")
  32. })
  33. })
  34. }
  35. func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
  36. const timeout = time.Millisecond * 10
  37. interceptor := UnaryTimeoutInterceptor(timeout)
  38. ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
  39. defer cancel()
  40. var wg sync.WaitGroup
  41. wg.Add(1)
  42. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  43. FullMethod: "/",
  44. }, func(ctx context.Context, req any) (any, error) {
  45. defer wg.Done()
  46. tm, ok := ctx.Deadline()
  47. assert.True(t, ok)
  48. assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
  49. return nil, nil
  50. })
  51. wg.Wait()
  52. assert.Nil(t, err)
  53. }
  54. func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
  55. const timeout = time.Millisecond * 10
  56. interceptor := UnaryTimeoutInterceptor(timeout)
  57. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
  58. defer cancel()
  59. var wg sync.WaitGroup
  60. wg.Add(1)
  61. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  62. FullMethod: "/",
  63. }, func(ctx context.Context, req any) (any, error) {
  64. defer wg.Done()
  65. time.Sleep(time.Millisecond * 50)
  66. return nil, nil
  67. })
  68. wg.Wait()
  69. assert.EqualValues(t, deadlineExceededErr, err)
  70. }
  71. func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
  72. const timeout = time.Minute * 10
  73. interceptor := UnaryTimeoutInterceptor(timeout)
  74. ctx, cancel := context.WithCancel(context.Background())
  75. cancel()
  76. var wg sync.WaitGroup
  77. wg.Add(1)
  78. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  79. FullMethod: "/",
  80. }, func(ctx context.Context, req any) (any, error) {
  81. defer wg.Done()
  82. time.Sleep(time.Millisecond * 50)
  83. return nil, nil
  84. })
  85. wg.Wait()
  86. assert.EqualValues(t, canceledErr, err)
  87. }
  // tempServer is a stand-in service value placed in grpc.UnaryServerInfo.Server
// by TestUnaryTimeoutInterceptor_TimeoutStrategy.
type tempServer struct {
	// timeout is assigned by tests; run does not read it.
	timeout time.Duration
}

// run simulates request processing by blocking the caller for d.
func (*tempServer) run(d time.Duration) {
	time.Sleep(d)
}
  94. func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
  95. type args struct {
  96. interceptorTimeout time.Duration
  97. contextTimeout time.Duration
  98. serverTimeout time.Duration
  99. runTime time.Duration
  100. fullMethod string
  101. }
  102. var tests = []struct {
  103. name string
  104. args args
  105. wantErr error
  106. }{
  107. {
  108. name: "do not timeout with interceptor timeout",
  109. args: args{
  110. interceptorTimeout: time.Second,
  111. contextTimeout: time.Second * 5,
  112. serverTimeout: time.Second * 3,
  113. runTime: time.Millisecond * 50,
  114. fullMethod: "/",
  115. },
  116. wantErr: nil,
  117. },
  118. {
  119. name: "timeout with interceptor timeout",
  120. args: args{
  121. interceptorTimeout: time.Second,
  122. contextTimeout: time.Second * 5,
  123. serverTimeout: time.Second * 3,
  124. runTime: time.Second * 2,
  125. fullMethod: "/",
  126. },
  127. wantErr: deadlineExceededErr,
  128. },
  129. }
  130. for _, tt := range tests {
  131. tt := tt
  132. t.Run(tt.name, func(t *testing.T) {
  133. t.Parallel()
  134. interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout)
  135. ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
  136. defer cancel()
  137. svr := &tempServer{timeout: tt.args.serverTimeout}
  138. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  139. Server: svr,
  140. FullMethod: tt.args.fullMethod,
  141. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  142. svr.run(tt.args.runTime)
  143. return nil, nil
  144. })
  145. t.Logf("error: %+v", err)
  146. assert.EqualValues(t, tt.wantErr, err)
  147. })
  148. }
  149. }
  150. func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
  151. type args struct {
  152. interceptorTimeout time.Duration
  153. contextTimeout time.Duration
  154. method string
  155. methodTimeout time.Duration
  156. runTime time.Duration
  157. }
  158. var tests = []struct {
  159. name string
  160. args args
  161. wantErr error
  162. }{
  163. {
  164. name: "do not timeout without set timeout for full method",
  165. args: args{
  166. interceptorTimeout: time.Second,
  167. contextTimeout: time.Second * 5,
  168. method: "/run",
  169. runTime: time.Millisecond * 50,
  170. },
  171. wantErr: nil,
  172. },
  173. {
  174. name: "do not timeout with set timeout for full method",
  175. args: args{
  176. interceptorTimeout: time.Second,
  177. contextTimeout: time.Second * 5,
  178. method: "/run/do_not_timeout",
  179. methodTimeout: time.Second * 3,
  180. runTime: time.Second * 2,
  181. },
  182. wantErr: nil,
  183. },
  184. {
  185. name: "timeout with set timeout for full method",
  186. args: args{
  187. interceptorTimeout: time.Second,
  188. contextTimeout: time.Second * 5,
  189. method: "/run/timeout",
  190. methodTimeout: time.Millisecond * 100,
  191. runTime: time.Millisecond * 500,
  192. },
  193. wantErr: deadlineExceededErr,
  194. },
  195. }
  196. for _, tt := range tests {
  197. tt := tt
  198. t.Run(tt.name, func(t *testing.T) {
  199. t.Parallel()
  200. var specifiedTimeouts []MethodTimeoutConf
  201. if tt.args.methodTimeout > 0 {
  202. specifiedTimeouts = []MethodTimeoutConf{
  203. {
  204. FullMethod: tt.args.method,
  205. Timeout: tt.args.methodTimeout,
  206. },
  207. }
  208. }
  209. interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout, specifiedTimeouts...)
  210. ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
  211. defer cancel()
  212. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  213. FullMethod: tt.args.method,
  214. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  215. time.Sleep(tt.args.runTime)
  216. return nil, nil
  217. })
  218. t.Logf("error: %+v", err)
  219. assert.EqualValues(t, tt.wantErr, err)
  220. })
  221. }
  222. }