// timeoutinterceptor_test.go (6.7 KB)

  1. package serverinterceptors
  2. import (
  3. "context"
  4. "sync"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/codes"
  10. "google.golang.org/grpc/status"
  11. )
  12. var (
  13. deadlineExceededErr = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
  14. canceledErr = status.Error(codes.Canceled, context.Canceled.Error())
  15. )
  16. func TestUnaryTimeoutInterceptor(t *testing.T) {
  17. interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
  18. _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  19. FullMethod: "/",
  20. }, func(ctx context.Context, req any) (any, error) {
  21. return nil, nil
  22. })
  23. assert.Nil(t, err)
  24. }
  25. func TestUnaryTimeoutInterceptor_panic(t *testing.T) {
  26. interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
  27. assert.Panics(t, func() {
  28. _, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  29. FullMethod: "/",
  30. }, func(ctx context.Context, req any) (any, error) {
  31. panic("any")
  32. })
  33. })
  34. }
  35. func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
  36. const timeout = time.Millisecond * 10
  37. interceptor := UnaryTimeoutInterceptor(timeout)
  38. ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
  39. defer cancel()
  40. var wg sync.WaitGroup
  41. wg.Add(1)
  42. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  43. FullMethod: "/",
  44. }, func(ctx context.Context, req any) (any, error) {
  45. defer wg.Done()
  46. tm, ok := ctx.Deadline()
  47. assert.True(t, ok)
  48. assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
  49. return nil, nil
  50. })
  51. wg.Wait()
  52. assert.Nil(t, err)
  53. }
  54. func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
  55. const timeout = time.Millisecond * 10
  56. interceptor := UnaryTimeoutInterceptor(timeout)
  57. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
  58. defer cancel()
  59. var wg sync.WaitGroup
  60. wg.Add(1)
  61. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  62. FullMethod: "/",
  63. }, func(ctx context.Context, req any) (any, error) {
  64. defer wg.Done()
  65. time.Sleep(time.Millisecond * 50)
  66. return nil, nil
  67. })
  68. wg.Wait()
  69. assert.EqualValues(t, deadlineExceededErr, err)
  70. }
  71. func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
  72. const timeout = time.Minute * 10
  73. interceptor := UnaryTimeoutInterceptor(timeout)
  74. ctx, cancel := context.WithCancel(context.Background())
  75. cancel()
  76. var wg sync.WaitGroup
  77. wg.Add(1)
  78. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  79. FullMethod: "/",
  80. }, func(ctx context.Context, req any) (any, error) {
  81. defer wg.Done()
  82. time.Sleep(time.Millisecond * 50)
  83. return nil, nil
  84. })
  85. wg.Wait()
  86. assert.EqualValues(t, canceledErr, err)
  87. }
  88. type tempServer struct {
  89. timeout time.Duration
  90. }
  91. func (s *tempServer) run(duration time.Duration) {
  92. time.Sleep(duration)
  93. }
  94. func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
  95. if fullMethod == "/" {
  96. return defaultTimeout
  97. }
  98. return s.timeout
  99. }
  100. func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
  101. type args struct {
  102. interceptorTimeout time.Duration
  103. contextTimeout time.Duration
  104. serverTimeout time.Duration
  105. runTime time.Duration
  106. fullMethod string
  107. }
  108. var tests = []struct {
  109. name string
  110. args args
  111. wantErr error
  112. }{
  113. {
  114. name: "do not timeout with interceptor timeout",
  115. args: args{
  116. interceptorTimeout: time.Second,
  117. contextTimeout: time.Second * 5,
  118. serverTimeout: time.Second * 3,
  119. runTime: time.Millisecond * 50,
  120. fullMethod: "/",
  121. },
  122. wantErr: nil,
  123. },
  124. {
  125. name: "do not timeout with timeout strategy",
  126. args: args{
  127. interceptorTimeout: time.Second,
  128. contextTimeout: time.Second * 5,
  129. serverTimeout: time.Second * 3,
  130. runTime: time.Second * 2,
  131. fullMethod: "/2s",
  132. },
  133. wantErr: nil,
  134. },
  135. {
  136. name: "timeout with interceptor timeout",
  137. args: args{
  138. interceptorTimeout: time.Second,
  139. contextTimeout: time.Second * 5,
  140. serverTimeout: time.Second * 3,
  141. runTime: time.Second * 2,
  142. fullMethod: "/",
  143. },
  144. wantErr: deadlineExceededErr,
  145. },
  146. }
  147. for _, tt := range tests {
  148. tt := tt
  149. t.Run(tt.name, func(t *testing.T) {
  150. t.Parallel()
  151. interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout)
  152. ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
  153. defer cancel()
  154. svr := &tempServer{timeout: tt.args.serverTimeout}
  155. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  156. Server: svr,
  157. FullMethod: tt.args.fullMethod,
  158. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  159. svr.run(tt.args.runTime)
  160. return nil, nil
  161. })
  162. t.Logf("error: %+v", err)
  163. assert.EqualValues(t, tt.wantErr, err)
  164. })
  165. }
  166. }
  167. func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
  168. type args struct {
  169. interceptorTimeout time.Duration
  170. contextTimeout time.Duration
  171. method string
  172. methodTimeout time.Duration
  173. runTime time.Duration
  174. }
  175. var tests = []struct {
  176. name string
  177. args args
  178. wantErr error
  179. }{
  180. {
  181. name: "do not timeout without set timeout for full method",
  182. args: args{
  183. interceptorTimeout: time.Second,
  184. contextTimeout: time.Second * 5,
  185. method: "/run",
  186. runTime: time.Millisecond * 50,
  187. },
  188. wantErr: nil,
  189. },
  190. {
  191. name: "do not timeout with set timeout for full method",
  192. args: args{
  193. interceptorTimeout: time.Second,
  194. contextTimeout: time.Second * 5,
  195. method: "/run/do_not_timeout",
  196. methodTimeout: time.Second * 3,
  197. runTime: time.Second * 2,
  198. },
  199. wantErr: nil,
  200. },
  201. {
  202. name: "timeout with set timeout for full method",
  203. args: args{
  204. interceptorTimeout: time.Second,
  205. contextTimeout: time.Second * 5,
  206. method: "/run/timeout",
  207. methodTimeout: time.Millisecond * 100,
  208. runTime: time.Millisecond * 500,
  209. },
  210. wantErr: deadlineExceededErr,
  211. },
  212. }
  213. for _, tt := range tests {
  214. tt := tt
  215. t.Run(tt.name, func(t *testing.T) {
  216. t.Parallel()
  217. var specifiedTimeouts []ServerSpecifiedTimeoutConf
  218. if tt.args.methodTimeout > 0 {
  219. specifiedTimeouts = []ServerSpecifiedTimeoutConf{
  220. {
  221. FullMethod: tt.args.method,
  222. Timeout: tt.args.methodTimeout,
  223. },
  224. }
  225. }
  226. interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout, specifiedTimeouts...)
  227. ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
  228. defer cancel()
  229. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  230. FullMethod: tt.args.method,
  231. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  232. time.Sleep(tt.args.runTime)
  233. return nil, nil
  234. })
  235. t.Logf("error: %+v", err)
  236. assert.EqualValues(t, tt.wantErr, err)
  237. })
  238. }
  239. }