tracinginterceptor_test.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "sync/atomic"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/zeromicro/go-zero/core/trace"
  11. "google.golang.org/grpc"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/metadata"
  14. "google.golang.org/grpc/status"
  15. )
  16. func TestUnaryOpenTracingInterceptor_Disable(t *testing.T) {
  17. _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  18. FullMethod: "/",
  19. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  20. return nil, nil
  21. })
  22. assert.Nil(t, err)
  23. }
  24. func TestUnaryOpenTracingInterceptor_Enabled(t *testing.T) {
  25. trace.StartAgent(trace.Config{
  26. Name: "go-zero-test",
  27. Endpoint: "http://localhost:14268/api/traces",
  28. Batcher: "jaeger",
  29. Sampler: 1.0,
  30. })
  31. _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  32. FullMethod: "/package.TestService.GetUser",
  33. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  34. return nil, nil
  35. })
  36. assert.Nil(t, err)
  37. }
  38. func TestUnaryTracingInterceptor(t *testing.T) {
  39. var run int32
  40. var wg sync.WaitGroup
  41. wg.Add(1)
  42. _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  43. FullMethod: "/",
  44. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  45. defer wg.Done()
  46. atomic.AddInt32(&run, 1)
  47. return nil, nil
  48. })
  49. wg.Wait()
  50. assert.Nil(t, err)
  51. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  52. }
  53. func TestUnaryTracingInterceptor_WithError(t *testing.T) {
  54. tests := []struct {
  55. name string
  56. err error
  57. }{
  58. {
  59. name: "normal error",
  60. err: errors.New("dummy"),
  61. },
  62. {
  63. name: "grpc error",
  64. err: status.Error(codes.DataLoss, "dummy"),
  65. },
  66. }
  67. for _, test := range tests {
  68. test := test
  69. t.Run(test.name, func(t *testing.T) {
  70. t.Parallel()
  71. var wg sync.WaitGroup
  72. wg.Add(1)
  73. var md metadata.MD
  74. ctx := metadata.NewIncomingContext(context.Background(), md)
  75. _, err := UnaryTracingInterceptor(ctx, nil, &grpc.UnaryServerInfo{
  76. FullMethod: "/",
  77. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  78. defer wg.Done()
  79. return nil, test.err
  80. })
  81. wg.Wait()
  82. assert.Equal(t, test.err, err)
  83. })
  84. }
  85. }
  86. func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
  87. var run int32
  88. var wg sync.WaitGroup
  89. wg.Add(1)
  90. var md metadata.MD
  91. ctx := metadata.NewIncomingContext(context.Background(), md)
  92. stream := mockedServerStream{ctx: ctx}
  93. err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
  94. FullMethod: "/foo",
  95. }, func(svr interface{}, stream grpc.ServerStream) error {
  96. defer wg.Done()
  97. atomic.AddInt32(&run, 1)
  98. return nil
  99. })
  100. wg.Wait()
  101. assert.Nil(t, err)
  102. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  103. }
  104. func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
  105. tests := []struct {
  106. name string
  107. err error
  108. }{
  109. {
  110. name: "receive event",
  111. err: status.Error(codes.DataLoss, "dummy"),
  112. },
  113. {
  114. name: "error event",
  115. err: status.Error(codes.DataLoss, "dummy"),
  116. },
  117. }
  118. for _, test := range tests {
  119. test := test
  120. t.Run(test.name, func(t *testing.T) {
  121. t.Parallel()
  122. var wg sync.WaitGroup
  123. wg.Add(1)
  124. var md metadata.MD
  125. ctx := metadata.NewIncomingContext(context.Background(), md)
  126. stream := mockedServerStream{ctx: ctx}
  127. err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
  128. FullMethod: "/foo",
  129. }, func(svr interface{}, stream grpc.ServerStream) error {
  130. defer wg.Done()
  131. return test.err
  132. })
  133. wg.Wait()
  134. assert.Equal(t, test.err, err)
  135. })
  136. }
  137. }
  138. func TestStreamTracingInterceptor_WithError(t *testing.T) {
  139. tests := []struct {
  140. name string
  141. err error
  142. }{
  143. {
  144. name: "normal error",
  145. err: errors.New("dummy"),
  146. },
  147. {
  148. name: "grpc error",
  149. err: status.Error(codes.DataLoss, "dummy"),
  150. },
  151. }
  152. for _, test := range tests {
  153. test := test
  154. t.Run(test.name, func(t *testing.T) {
  155. t.Parallel()
  156. var wg sync.WaitGroup
  157. wg.Add(1)
  158. var md metadata.MD
  159. ctx := metadata.NewIncomingContext(context.Background(), md)
  160. stream := mockedServerStream{ctx: ctx}
  161. err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
  162. FullMethod: "/foo",
  163. }, func(svr interface{}, stream grpc.ServerStream) error {
  164. defer wg.Done()
  165. return test.err
  166. })
  167. wg.Wait()
  168. assert.Equal(t, test.err, err)
  169. })
  170. }
  171. }
  172. func TestClientStream_RecvMsg(t *testing.T) {
  173. tests := []struct {
  174. name string
  175. err error
  176. }{
  177. {
  178. name: "nil error",
  179. },
  180. {
  181. name: "EOF",
  182. err: io.EOF,
  183. },
  184. {
  185. name: "dummy error",
  186. err: errors.New("dummy"),
  187. },
  188. }
  189. for _, test := range tests {
  190. test := test
  191. t.Run(test.name, func(t *testing.T) {
  192. t.Parallel()
  193. stream := wrapServerStream(context.Background(), &mockedServerStream{
  194. ctx: context.Background(),
  195. err: test.err,
  196. })
  197. assert.Equal(t, test.err, stream.RecvMsg(nil))
  198. })
  199. }
  200. }
  201. func TestServerStream_SendMsg(t *testing.T) {
  202. tests := []struct {
  203. name string
  204. err error
  205. }{
  206. {
  207. name: "nil error",
  208. },
  209. {
  210. name: "with error",
  211. err: errors.New("dummy"),
  212. },
  213. }
  214. for _, test := range tests {
  215. test := test
  216. t.Run(test.name, func(t *testing.T) {
  217. t.Parallel()
  218. stream := wrapServerStream(context.Background(), &mockedServerStream{
  219. ctx: context.Background(),
  220. err: test.err,
  221. })
  222. assert.Equal(t, test.err, stream.SendMsg(nil))
  223. })
  224. }
  225. }
  226. type mockedServerStream struct {
  227. ctx context.Context
  228. err error
  229. }
  230. func (m *mockedServerStream) SetHeader(md metadata.MD) error {
  231. panic("implement me")
  232. }
  233. func (m *mockedServerStream) SendHeader(md metadata.MD) error {
  234. panic("implement me")
  235. }
  236. func (m *mockedServerStream) SetTrailer(md metadata.MD) {
  237. panic("implement me")
  238. }
  239. func (m *mockedServerStream) Context() context.Context {
  240. if m.ctx == nil {
  241. return context.Background()
  242. }
  243. return m.ctx
  244. }
  245. func (m *mockedServerStream) SendMsg(v interface{}) error {
  246. return m.err
  247. }
  248. func (m *mockedServerStream) RecvMsg(v interface{}) error {
  249. return m.err
  250. }