// tracinginterceptor_test.go (8.0 KB)
  1. package clientinterceptors
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "sync/atomic"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/tal-tech/go-zero/core/trace"
  11. "google.golang.org/grpc"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/metadata"
  14. "google.golang.org/grpc/status"
  15. )
  16. func TestOpenTracingInterceptor(t *testing.T) {
  17. trace.StartAgent(trace.Config{
  18. Name: "go-zero-test",
  19. Endpoint: "http://localhost:14268/api/traces",
  20. Batcher: "jaeger",
  21. Sampler: 1.0,
  22. })
  23. cc := new(grpc.ClientConn)
  24. ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{})
  25. err := UnaryTracingInterceptor(ctx, "/ListUser", nil, nil, cc,
  26. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  27. opts ...grpc.CallOption) error {
  28. return nil
  29. })
  30. assert.Nil(t, err)
  31. }
  32. func TestUnaryTracingInterceptor(t *testing.T) {
  33. var run int32
  34. var wg sync.WaitGroup
  35. wg.Add(1)
  36. cc := new(grpc.ClientConn)
  37. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  38. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  39. opts ...grpc.CallOption) error {
  40. defer wg.Done()
  41. atomic.AddInt32(&run, 1)
  42. return nil
  43. })
  44. wg.Wait()
  45. assert.Nil(t, err)
  46. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  47. }
  48. func TestUnaryTracingInterceptor_WithError(t *testing.T) {
  49. var run int32
  50. var wg sync.WaitGroup
  51. wg.Add(1)
  52. cc := new(grpc.ClientConn)
  53. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  54. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  55. opts ...grpc.CallOption) error {
  56. defer wg.Done()
  57. atomic.AddInt32(&run, 1)
  58. return errors.New("dummy")
  59. })
  60. wg.Wait()
  61. assert.NotNil(t, err)
  62. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  63. }
  64. func TestStreamTracingInterceptor(t *testing.T) {
  65. var run int32
  66. var wg sync.WaitGroup
  67. wg.Add(1)
  68. cc := new(grpc.ClientConn)
  69. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  70. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  71. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  72. defer wg.Done()
  73. atomic.AddInt32(&run, 1)
  74. return nil, nil
  75. })
  76. wg.Wait()
  77. assert.Nil(t, err)
  78. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  79. }
  80. func TestStreamTracingInterceptor_FinishWithNormalError(t *testing.T) {
  81. var wg sync.WaitGroup
  82. wg.Add(1)
  83. cc := new(grpc.ClientConn)
  84. ctx, cancel := context.WithCancel(context.Background())
  85. stream, err := StreamTracingInterceptor(ctx, nil, cc, "/foo",
  86. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  87. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  88. defer wg.Done()
  89. return nil, nil
  90. })
  91. wg.Wait()
  92. assert.Nil(t, err)
  93. cancel()
  94. cs := stream.(*clientStream)
  95. <-cs.eventsDone
  96. }
  97. func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
  98. tests := []struct {
  99. name string
  100. event streamEventType
  101. err error
  102. }{
  103. {
  104. name: "receive event",
  105. event: receiveEndEvent,
  106. err: status.Error(codes.DataLoss, "dummy"),
  107. },
  108. {
  109. name: "error event",
  110. event: errorEvent,
  111. err: status.Error(codes.DataLoss, "dummy"),
  112. },
  113. }
  114. for _, test := range tests {
  115. test := test
  116. t.Run(test.name, func(t *testing.T) {
  117. t.Parallel()
  118. var wg sync.WaitGroup
  119. wg.Add(1)
  120. cc := new(grpc.ClientConn)
  121. stream, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  122. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  123. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  124. defer wg.Done()
  125. return &mockedClientStream{
  126. err: errors.New("dummy"),
  127. }, nil
  128. })
  129. wg.Wait()
  130. assert.Nil(t, err)
  131. cs := stream.(*clientStream)
  132. cs.sendStreamEvent(test.event, status.Error(codes.DataLoss, "dummy"))
  133. <-cs.eventsDone
  134. cs.sendStreamEvent(test.event, test.err)
  135. assert.NotNil(t, cs.CloseSend())
  136. })
  137. }
  138. }
  139. func TestStreamTracingInterceptor_WithError(t *testing.T) {
  140. tests := []struct {
  141. name string
  142. err error
  143. }{
  144. {
  145. name: "normal error",
  146. err: errors.New("dummy"),
  147. },
  148. {
  149. name: "grpc error",
  150. err: status.Error(codes.DataLoss, "dummy"),
  151. },
  152. }
  153. for _, test := range tests {
  154. test := test
  155. t.Run(test.name, func(t *testing.T) {
  156. t.Parallel()
  157. var run int32
  158. var wg sync.WaitGroup
  159. wg.Add(1)
  160. cc := new(grpc.ClientConn)
  161. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  162. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  163. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  164. defer wg.Done()
  165. atomic.AddInt32(&run, 1)
  166. return new(mockedClientStream), test.err
  167. })
  168. wg.Wait()
  169. assert.NotNil(t, err)
  170. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  171. })
  172. }
  173. }
  174. func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
  175. var run int32
  176. var wg sync.WaitGroup
  177. wg.Add(1)
  178. cc := new(grpc.ClientConn)
  179. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  180. func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  181. opts ...grpc.CallOption) error {
  182. defer wg.Done()
  183. atomic.AddInt32(&run, 1)
  184. return nil
  185. })
  186. wg.Wait()
  187. assert.Nil(t, err)
  188. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  189. }
  190. func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
  191. var run int32
  192. var wg sync.WaitGroup
  193. wg.Add(1)
  194. cc := new(grpc.ClientConn)
  195. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  196. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  197. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  198. defer wg.Done()
  199. atomic.AddInt32(&run, 1)
  200. return nil, nil
  201. })
  202. wg.Wait()
  203. assert.Nil(t, err)
  204. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  205. }
  206. func TestClientStream_RecvMsg(t *testing.T) {
  207. tests := []struct {
  208. name string
  209. serverStreams bool
  210. err error
  211. }{
  212. {
  213. name: "nil error",
  214. },
  215. {
  216. name: "EOF",
  217. err: io.EOF,
  218. },
  219. {
  220. name: "dummy error",
  221. err: errors.New("dummy"),
  222. },
  223. {
  224. name: "server streams",
  225. serverStreams: true,
  226. },
  227. }
  228. for _, test := range tests {
  229. test := test
  230. t.Run(test.name, func(t *testing.T) {
  231. t.Parallel()
  232. desc := new(grpc.StreamDesc)
  233. desc.ServerStreams = test.serverStreams
  234. stream := wrapClientStream(context.Background(), &mockedClientStream{
  235. md: nil,
  236. err: test.err,
  237. }, desc)
  238. assert.Equal(t, test.err, stream.RecvMsg(nil))
  239. })
  240. }
  241. }
  242. func TestClientStream_Header(t *testing.T) {
  243. tests := []struct {
  244. name string
  245. err error
  246. }{
  247. {
  248. name: "nil error",
  249. },
  250. {
  251. name: "with error",
  252. err: errors.New("dummy"),
  253. },
  254. }
  255. for _, test := range tests {
  256. test := test
  257. t.Run(test.name, func(t *testing.T) {
  258. t.Parallel()
  259. desc := new(grpc.StreamDesc)
  260. stream := wrapClientStream(context.Background(), &mockedClientStream{
  261. md: metadata.MD{},
  262. err: test.err,
  263. }, desc)
  264. _, err := stream.Header()
  265. assert.Equal(t, test.err, err)
  266. })
  267. }
  268. }
  269. func TestClientStream_SendMsg(t *testing.T) {
  270. tests := []struct {
  271. name string
  272. err error
  273. }{
  274. {
  275. name: "nil error",
  276. },
  277. {
  278. name: "with error",
  279. err: errors.New("dummy"),
  280. },
  281. }
  282. for _, test := range tests {
  283. test := test
  284. t.Run(test.name, func(t *testing.T) {
  285. t.Parallel()
  286. desc := new(grpc.StreamDesc)
  287. stream := wrapClientStream(context.Background(), &mockedClientStream{
  288. md: metadata.MD{},
  289. err: test.err,
  290. }, desc)
  291. assert.Equal(t, test.err, stream.SendMsg(nil))
  292. })
  293. }
  294. }
  295. type mockedClientStream struct {
  296. md metadata.MD
  297. err error
  298. }
  299. func (m *mockedClientStream) Header() (metadata.MD, error) {
  300. return m.md, m.err
  301. }
  302. func (m *mockedClientStream) Trailer() metadata.MD {
  303. panic("implement me")
  304. }
  305. func (m *mockedClientStream) CloseSend() error {
  306. return m.err
  307. }
  308. func (m *mockedClientStream) Context() context.Context {
  309. return context.Background()
  310. }
  311. func (m *mockedClientStream) SendMsg(v interface{}) error {
  312. return m.err
  313. }
  314. func (m *mockedClientStream) RecvMsg(v interface{}) error {
  315. return m.err
  316. }