tracinginterceptor_test.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. package clientinterceptors
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "sync/atomic"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/zeromicro/go-zero/core/trace"
  11. "google.golang.org/grpc"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/metadata"
  14. "google.golang.org/grpc/status"
  15. )
  16. func TestOpenTracingInterceptor(t *testing.T) {
  17. trace.StartAgent(trace.Config{
  18. Name: "go-zero-test",
  19. Endpoint: "http://localhost:14268/api/traces",
  20. Batcher: "jaeger",
  21. Sampler: 1.0,
  22. })
  23. defer trace.StopAgent()
  24. cc := new(grpc.ClientConn)
  25. ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{})
  26. err := UnaryTracingInterceptor(ctx, "/ListUser", nil, nil, cc,
  27. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  28. opts ...grpc.CallOption) error {
  29. return nil
  30. })
  31. assert.Nil(t, err)
  32. }
  33. func TestUnaryTracingInterceptor(t *testing.T) {
  34. var run int32
  35. var wg sync.WaitGroup
  36. wg.Add(1)
  37. cc := new(grpc.ClientConn)
  38. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  39. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  40. opts ...grpc.CallOption) error {
  41. defer wg.Done()
  42. atomic.AddInt32(&run, 1)
  43. return nil
  44. })
  45. wg.Wait()
  46. assert.Nil(t, err)
  47. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  48. }
  49. func TestUnaryTracingInterceptor_WithError(t *testing.T) {
  50. var run int32
  51. var wg sync.WaitGroup
  52. wg.Add(1)
  53. cc := new(grpc.ClientConn)
  54. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  55. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  56. opts ...grpc.CallOption) error {
  57. defer wg.Done()
  58. atomic.AddInt32(&run, 1)
  59. return errors.New("dummy")
  60. })
  61. wg.Wait()
  62. assert.NotNil(t, err)
  63. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  64. }
  65. func TestStreamTracingInterceptor(t *testing.T) {
  66. var run int32
  67. var wg sync.WaitGroup
  68. wg.Add(1)
  69. cc := new(grpc.ClientConn)
  70. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  71. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  72. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  73. defer wg.Done()
  74. atomic.AddInt32(&run, 1)
  75. return nil, nil
  76. })
  77. wg.Wait()
  78. assert.Nil(t, err)
  79. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  80. }
  81. func TestStreamTracingInterceptor_FinishWithNormalError(t *testing.T) {
  82. var wg sync.WaitGroup
  83. wg.Add(1)
  84. cc := new(grpc.ClientConn)
  85. ctx, cancel := context.WithCancel(context.Background())
  86. stream, err := StreamTracingInterceptor(ctx, nil, cc, "/foo",
  87. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  88. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  89. defer wg.Done()
  90. return nil, nil
  91. })
  92. wg.Wait()
  93. assert.Nil(t, err)
  94. cancel()
  95. cs := stream.(*clientStream)
  96. <-cs.eventsDone
  97. }
  98. func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
  99. tests := []struct {
  100. name string
  101. event streamEventType
  102. err error
  103. }{
  104. {
  105. name: "receive event",
  106. event: receiveEndEvent,
  107. err: status.Error(codes.DataLoss, "dummy"),
  108. },
  109. {
  110. name: "error event",
  111. event: errorEvent,
  112. err: status.Error(codes.DataLoss, "dummy"),
  113. },
  114. }
  115. for _, test := range tests {
  116. test := test
  117. t.Run(test.name, func(t *testing.T) {
  118. t.Parallel()
  119. var wg sync.WaitGroup
  120. wg.Add(1)
  121. cc := new(grpc.ClientConn)
  122. stream, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  123. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  124. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  125. defer wg.Done()
  126. return &mockedClientStream{
  127. err: errors.New("dummy"),
  128. }, nil
  129. })
  130. wg.Wait()
  131. assert.Nil(t, err)
  132. cs := stream.(*clientStream)
  133. cs.sendStreamEvent(test.event, status.Error(codes.DataLoss, "dummy"))
  134. <-cs.eventsDone
  135. cs.sendStreamEvent(test.event, test.err)
  136. assert.NotNil(t, cs.CloseSend())
  137. })
  138. }
  139. }
  140. func TestStreamTracingInterceptor_WithError(t *testing.T) {
  141. tests := []struct {
  142. name string
  143. err error
  144. }{
  145. {
  146. name: "normal error",
  147. err: errors.New("dummy"),
  148. },
  149. {
  150. name: "grpc error",
  151. err: status.Error(codes.DataLoss, "dummy"),
  152. },
  153. }
  154. for _, test := range tests {
  155. test := test
  156. t.Run(test.name, func(t *testing.T) {
  157. t.Parallel()
  158. var run int32
  159. var wg sync.WaitGroup
  160. wg.Add(1)
  161. cc := new(grpc.ClientConn)
  162. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  163. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  164. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  165. defer wg.Done()
  166. atomic.AddInt32(&run, 1)
  167. return new(mockedClientStream), test.err
  168. })
  169. wg.Wait()
  170. assert.NotNil(t, err)
  171. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  172. })
  173. }
  174. }
  175. func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
  176. var run int32
  177. var wg sync.WaitGroup
  178. wg.Add(1)
  179. cc := new(grpc.ClientConn)
  180. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  181. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  182. opts ...grpc.CallOption) error {
  183. defer wg.Done()
  184. atomic.AddInt32(&run, 1)
  185. return nil
  186. })
  187. wg.Wait()
  188. assert.Nil(t, err)
  189. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  190. }
  191. func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
  192. var run int32
  193. var wg sync.WaitGroup
  194. wg.Add(1)
  195. cc := new(grpc.ClientConn)
  196. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  197. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  198. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  199. defer wg.Done()
  200. atomic.AddInt32(&run, 1)
  201. return nil, nil
  202. })
  203. wg.Wait()
  204. assert.Nil(t, err)
  205. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  206. }
  207. func TestClientStream_RecvMsg(t *testing.T) {
  208. tests := []struct {
  209. name string
  210. serverStreams bool
  211. err error
  212. }{
  213. {
  214. name: "nil error",
  215. },
  216. {
  217. name: "EOF",
  218. err: io.EOF,
  219. },
  220. {
  221. name: "dummy error",
  222. err: errors.New("dummy"),
  223. },
  224. {
  225. name: "server streams",
  226. serverStreams: true,
  227. },
  228. }
  229. for _, test := range tests {
  230. test := test
  231. t.Run(test.name, func(t *testing.T) {
  232. t.Parallel()
  233. desc := new(grpc.StreamDesc)
  234. desc.ServerStreams = test.serverStreams
  235. stream := wrapClientStream(context.Background(), &mockedClientStream{
  236. md: nil,
  237. err: test.err,
  238. }, desc)
  239. assert.Equal(t, test.err, stream.RecvMsg(nil))
  240. })
  241. }
  242. }
  243. func TestClientStream_Header(t *testing.T) {
  244. tests := []struct {
  245. name string
  246. err error
  247. }{
  248. {
  249. name: "nil error",
  250. },
  251. {
  252. name: "with error",
  253. err: errors.New("dummy"),
  254. },
  255. }
  256. for _, test := range tests {
  257. test := test
  258. t.Run(test.name, func(t *testing.T) {
  259. t.Parallel()
  260. desc := new(grpc.StreamDesc)
  261. stream := wrapClientStream(context.Background(), &mockedClientStream{
  262. md: metadata.MD{},
  263. err: test.err,
  264. }, desc)
  265. _, err := stream.Header()
  266. assert.Equal(t, test.err, err)
  267. })
  268. }
  269. }
  270. func TestClientStream_SendMsg(t *testing.T) {
  271. tests := []struct {
  272. name string
  273. err error
  274. }{
  275. {
  276. name: "nil error",
  277. },
  278. {
  279. name: "with error",
  280. err: errors.New("dummy"),
  281. },
  282. }
  283. for _, test := range tests {
  284. test := test
  285. t.Run(test.name, func(t *testing.T) {
  286. t.Parallel()
  287. desc := new(grpc.StreamDesc)
  288. stream := wrapClientStream(context.Background(), &mockedClientStream{
  289. md: metadata.MD{},
  290. err: test.err,
  291. }, desc)
  292. assert.Equal(t, test.err, stream.SendMsg(nil))
  293. })
  294. }
  295. }
  296. type mockedClientStream struct {
  297. md metadata.MD
  298. err error
  299. }
  300. func (m *mockedClientStream) Header() (metadata.MD, error) {
  301. return m.md, m.err
  302. }
  303. func (m *mockedClientStream) Trailer() metadata.MD {
  304. panic("implement me")
  305. }
  306. func (m *mockedClientStream) CloseSend() error {
  307. return m.err
  308. }
  309. func (m *mockedClientStream) Context() context.Context {
  310. return context.Background()
  311. }
  312. func (m *mockedClientStream) SendMsg(v interface{}) error {
  313. return m.err
  314. }
  315. func (m *mockedClientStream) RecvMsg(v interface{}) error {
  316. return m.err
  317. }