timeoutinterceptor_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package clientinterceptors
  2. import (
  3. "context"
  4. "strconv"
  5. "sync"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. "google.golang.org/grpc"
  10. )
  11. func TestTimeoutInterceptor(t *testing.T) {
  12. timeouts := []time.Duration{0, time.Millisecond * 10}
  13. for _, timeout := range timeouts {
  14. t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
  15. interceptor := TimeoutInterceptor(timeout)
  16. cc := new(grpc.ClientConn)
  17. err := interceptor(context.Background(), "/foo", nil, nil, cc,
  18. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  19. opts ...grpc.CallOption) error {
  20. return nil
  21. },
  22. )
  23. assert.Nil(t, err)
  24. })
  25. }
  26. }
  27. func TestTimeoutInterceptor_timeout(t *testing.T) {
  28. const timeout = time.Millisecond * 10
  29. interceptor := TimeoutInterceptor(timeout)
  30. ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
  31. defer cancel()
  32. var wg sync.WaitGroup
  33. wg.Add(1)
  34. cc := new(grpc.ClientConn)
  35. err := interceptor(ctx, "/foo", nil, nil, cc,
  36. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  37. opts ...grpc.CallOption) error {
  38. defer wg.Done()
  39. tm, ok := ctx.Deadline()
  40. assert.True(t, ok)
  41. assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
  42. return nil
  43. })
  44. wg.Wait()
  45. assert.Nil(t, err)
  46. }
  47. func TestTimeoutInterceptor_panic(t *testing.T) {
  48. timeouts := []time.Duration{0, time.Millisecond * 10}
  49. for _, timeout := range timeouts {
  50. t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
  51. interceptor := TimeoutInterceptor(timeout)
  52. cc := new(grpc.ClientConn)
  53. assert.Panics(t, func() {
  54. _ = interceptor(context.Background(), "/foo", nil, nil, cc,
  55. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  56. opts ...grpc.CallOption) error {
  57. panic("any")
  58. },
  59. )
  60. })
  61. })
  62. }
  63. }
  64. func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
  65. type args struct {
  66. interceptorTimeout time.Duration
  67. callOptionTimeout time.Duration
  68. runTime time.Duration
  69. }
  70. var tests = []struct {
  71. name string
  72. args args
  73. wantErr error
  74. }{
  75. {
  76. name: "do not timeout without call option timeout",
  77. args: args{
  78. interceptorTimeout: time.Second,
  79. runTime: time.Millisecond * 50,
  80. },
  81. wantErr: nil,
  82. },
  83. {
  84. name: "timeout without call option timeout",
  85. args: args{
  86. interceptorTimeout: time.Second,
  87. runTime: time.Second * 2,
  88. },
  89. wantErr: context.DeadlineExceeded,
  90. },
  91. {
  92. name: "do not timeout with call option timeout",
  93. args: args{
  94. interceptorTimeout: time.Second,
  95. callOptionTimeout: time.Second * 3,
  96. runTime: time.Second * 2,
  97. },
  98. wantErr: nil,
  99. },
  100. }
  101. for _, tt := range tests {
  102. tt := tt
  103. t.Run(tt.name, func(t *testing.T) {
  104. t.Parallel()
  105. interceptor := TimeoutInterceptor(tt.args.interceptorTimeout)
  106. cc := new(grpc.ClientConn)
  107. var co []grpc.CallOption
  108. if tt.args.callOptionTimeout > 0 {
  109. co = append(co, WithCallTimeout(tt.args.callOptionTimeout))
  110. }
  111. err := interceptor(context.Background(), "/foo", nil, nil, cc,
  112. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  113. opts ...grpc.CallOption) error {
  114. timer := time.NewTimer(tt.args.runTime)
  115. defer timer.Stop()
  116. select {
  117. case <-timer.C:
  118. return nil
  119. case <-ctx.Done():
  120. return ctx.Err()
  121. }
  122. }, co...,
  123. )
  124. t.Logf("error: %+v", err)
  125. assert.EqualValues(t, tt.wantErr, err)
  126. })
  127. }
  128. }