tracinginterceptor_test.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "sync/atomic"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/zeromicro/go-zero/core/trace"
  11. "google.golang.org/grpc"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/metadata"
  14. "google.golang.org/grpc/status"
  15. )
  16. func TestUnaryOpenTracingInterceptor_Disable(t *testing.T) {
  17. _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  18. FullMethod: "/",
  19. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  20. return nil, nil
  21. })
  22. assert.Nil(t, err)
  23. }
  24. func TestUnaryOpenTracingInterceptor_Enabled(t *testing.T) {
  25. trace.StartAgent(trace.Config{
  26. Name: "go-zero-test",
  27. Endpoint: "http://localhost:14268/api/traces",
  28. Batcher: "jaeger",
  29. Sampler: 1.0,
  30. })
  31. defer trace.StopAgent()
  32. _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  33. FullMethod: "/package.TestService.GetUser",
  34. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  35. return nil, nil
  36. })
  37. assert.Nil(t, err)
  38. }
  39. func TestUnaryTracingInterceptor(t *testing.T) {
  40. var run int32
  41. var wg sync.WaitGroup
  42. wg.Add(1)
  43. _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  44. FullMethod: "/",
  45. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  46. defer wg.Done()
  47. atomic.AddInt32(&run, 1)
  48. return nil, nil
  49. })
  50. wg.Wait()
  51. assert.Nil(t, err)
  52. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  53. }
  54. func TestUnaryTracingInterceptor_WithError(t *testing.T) {
  55. tests := []struct {
  56. name string
  57. err error
  58. }{
  59. {
  60. name: "normal error",
  61. err: errors.New("dummy"),
  62. },
  63. {
  64. name: "grpc error",
  65. err: status.Error(codes.DataLoss, "dummy"),
  66. },
  67. }
  68. for _, test := range tests {
  69. test := test
  70. t.Run(test.name, func(t *testing.T) {
  71. t.Parallel()
  72. var wg sync.WaitGroup
  73. wg.Add(1)
  74. var md metadata.MD
  75. ctx := metadata.NewIncomingContext(context.Background(), md)
  76. _, err := UnaryTracingInterceptor(ctx, nil, &grpc.UnaryServerInfo{
  77. FullMethod: "/",
  78. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  79. defer wg.Done()
  80. return nil, test.err
  81. })
  82. wg.Wait()
  83. assert.Equal(t, test.err, err)
  84. })
  85. }
  86. }
  87. func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
  88. var run int32
  89. var wg sync.WaitGroup
  90. wg.Add(1)
  91. var md metadata.MD
  92. ctx := metadata.NewIncomingContext(context.Background(), md)
  93. stream := mockedServerStream{ctx: ctx}
  94. err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
  95. FullMethod: "/foo",
  96. }, func(svr interface{}, stream grpc.ServerStream) error {
  97. defer wg.Done()
  98. atomic.AddInt32(&run, 1)
  99. return nil
  100. })
  101. wg.Wait()
  102. assert.Nil(t, err)
  103. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  104. }
  105. func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
  106. tests := []struct {
  107. name string
  108. err error
  109. }{
  110. {
  111. name: "receive event",
  112. err: status.Error(codes.DataLoss, "dummy"),
  113. },
  114. {
  115. name: "error event",
  116. err: status.Error(codes.DataLoss, "dummy"),
  117. },
  118. }
  119. for _, test := range tests {
  120. test := test
  121. t.Run(test.name, func(t *testing.T) {
  122. t.Parallel()
  123. var wg sync.WaitGroup
  124. wg.Add(1)
  125. var md metadata.MD
  126. ctx := metadata.NewIncomingContext(context.Background(), md)
  127. stream := mockedServerStream{ctx: ctx}
  128. err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
  129. FullMethod: "/foo",
  130. }, func(svr interface{}, stream grpc.ServerStream) error {
  131. defer wg.Done()
  132. return test.err
  133. })
  134. wg.Wait()
  135. assert.Equal(t, test.err, err)
  136. })
  137. }
  138. }
  139. func TestStreamTracingInterceptor_WithError(t *testing.T) {
  140. tests := []struct {
  141. name string
  142. err error
  143. }{
  144. {
  145. name: "normal error",
  146. err: errors.New("dummy"),
  147. },
  148. {
  149. name: "grpc error",
  150. err: status.Error(codes.DataLoss, "dummy"),
  151. },
  152. }
  153. for _, test := range tests {
  154. test := test
  155. t.Run(test.name, func(t *testing.T) {
  156. t.Parallel()
  157. var wg sync.WaitGroup
  158. wg.Add(1)
  159. var md metadata.MD
  160. ctx := metadata.NewIncomingContext(context.Background(), md)
  161. stream := mockedServerStream{ctx: ctx}
  162. err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
  163. FullMethod: "/foo",
  164. }, func(svr interface{}, stream grpc.ServerStream) error {
  165. defer wg.Done()
  166. return test.err
  167. })
  168. wg.Wait()
  169. assert.Equal(t, test.err, err)
  170. })
  171. }
  172. }
  173. func TestClientStream_RecvMsg(t *testing.T) {
  174. tests := []struct {
  175. name string
  176. err error
  177. }{
  178. {
  179. name: "nil error",
  180. },
  181. {
  182. name: "EOF",
  183. err: io.EOF,
  184. },
  185. {
  186. name: "dummy error",
  187. err: errors.New("dummy"),
  188. },
  189. }
  190. for _, test := range tests {
  191. test := test
  192. t.Run(test.name, func(t *testing.T) {
  193. t.Parallel()
  194. stream := wrapServerStream(context.Background(), &mockedServerStream{
  195. ctx: context.Background(),
  196. err: test.err,
  197. })
  198. assert.Equal(t, test.err, stream.RecvMsg(nil))
  199. })
  200. }
  201. }
  202. func TestServerStream_SendMsg(t *testing.T) {
  203. tests := []struct {
  204. name string
  205. err error
  206. }{
  207. {
  208. name: "nil error",
  209. },
  210. {
  211. name: "with error",
  212. err: errors.New("dummy"),
  213. },
  214. }
  215. for _, test := range tests {
  216. test := test
  217. t.Run(test.name, func(t *testing.T) {
  218. t.Parallel()
  219. stream := wrapServerStream(context.Background(), &mockedServerStream{
  220. ctx: context.Background(),
  221. err: test.err,
  222. })
  223. assert.Equal(t, test.err, stream.SendMsg(nil))
  224. })
  225. }
  226. }
  227. type mockedServerStream struct {
  228. ctx context.Context
  229. err error
  230. }
  231. func (m *mockedServerStream) SetHeader(md metadata.MD) error {
  232. panic("implement me")
  233. }
  234. func (m *mockedServerStream) SendHeader(md metadata.MD) error {
  235. panic("implement me")
  236. }
  237. func (m *mockedServerStream) SetTrailer(md metadata.MD) {
  238. panic("implement me")
  239. }
  240. func (m *mockedServerStream) Context() context.Context {
  241. if m.ctx == nil {
  242. return context.Background()
  243. }
  244. return m.ctx
  245. }
  246. func (m *mockedServerStream) SendMsg(v interface{}) error {
  247. return m.err
  248. }
  249. func (m *mockedServerStream) RecvMsg(v interface{}) error {
  250. return m.err
  251. }