tracinginterceptor_test.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. package clientinterceptors
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "sync/atomic"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/zeromicro/go-zero/core/trace"
  11. "google.golang.org/grpc"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/metadata"
  14. "google.golang.org/grpc/status"
  15. )
  16. func TestOpenTracingInterceptor(t *testing.T) {
  17. trace.StartAgent(trace.Config{
  18. Name: "go-zero-test",
  19. Endpoint: "http://localhost:14268/api/traces",
  20. Batcher: "jaeger",
  21. Sampler: 1.0,
  22. })
  23. defer trace.StopAgent()
  24. cc := new(grpc.ClientConn)
  25. ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{})
  26. err := UnaryTracingInterceptor(ctx, "/ListUser", nil, nil, cc,
  27. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  28. opts ...grpc.CallOption) error {
  29. return nil
  30. })
  31. assert.Nil(t, err)
  32. }
  33. func TestUnaryTracingInterceptor(t *testing.T) {
  34. var run int32
  35. var wg sync.WaitGroup
  36. wg.Add(1)
  37. cc := new(grpc.ClientConn)
  38. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  39. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  40. opts ...grpc.CallOption) error {
  41. defer wg.Done()
  42. atomic.AddInt32(&run, 1)
  43. return nil
  44. })
  45. wg.Wait()
  46. assert.Nil(t, err)
  47. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  48. }
  49. func TestUnaryTracingInterceptor_WithError(t *testing.T) {
  50. var run int32
  51. var wg sync.WaitGroup
  52. wg.Add(1)
  53. cc := new(grpc.ClientConn)
  54. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  55. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  56. opts ...grpc.CallOption) error {
  57. defer wg.Done()
  58. atomic.AddInt32(&run, 1)
  59. return errors.New("dummy")
  60. })
  61. wg.Wait()
  62. assert.NotNil(t, err)
  63. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  64. }
  65. func TestUnaryTracingInterceptor_WithStatusError(t *testing.T) {
  66. var run int32
  67. var wg sync.WaitGroup
  68. wg.Add(1)
  69. cc := new(grpc.ClientConn)
  70. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  71. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  72. opts ...grpc.CallOption) error {
  73. defer wg.Done()
  74. atomic.AddInt32(&run, 1)
  75. return status.Error(codes.DataLoss, "dummy")
  76. })
  77. wg.Wait()
  78. assert.NotNil(t, err)
  79. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  80. }
  81. func TestStreamTracingInterceptor(t *testing.T) {
  82. var run int32
  83. var wg sync.WaitGroup
  84. wg.Add(1)
  85. cc := new(grpc.ClientConn)
  86. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  87. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  88. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  89. defer wg.Done()
  90. atomic.AddInt32(&run, 1)
  91. return nil, nil
  92. })
  93. wg.Wait()
  94. assert.Nil(t, err)
  95. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  96. }
  97. func TestStreamTracingInterceptor_FinishWithNormalError(t *testing.T) {
  98. var wg sync.WaitGroup
  99. wg.Add(1)
  100. cc := new(grpc.ClientConn)
  101. ctx, cancel := context.WithCancel(context.Background())
  102. stream, err := StreamTracingInterceptor(ctx, nil, cc, "/foo",
  103. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  104. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  105. defer wg.Done()
  106. return nil, nil
  107. })
  108. wg.Wait()
  109. assert.Nil(t, err)
  110. cancel()
  111. cs := stream.(*clientStream)
  112. <-cs.eventsDone
  113. }
  114. func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
  115. tests := []struct {
  116. name string
  117. event streamEventType
  118. err error
  119. }{
  120. {
  121. name: "receive event",
  122. event: receiveEndEvent,
  123. err: status.Error(codes.DataLoss, "dummy"),
  124. },
  125. {
  126. name: "error event",
  127. event: errorEvent,
  128. err: status.Error(codes.DataLoss, "dummy"),
  129. },
  130. }
  131. for _, test := range tests {
  132. test := test
  133. t.Run(test.name, func(t *testing.T) {
  134. t.Parallel()
  135. var wg sync.WaitGroup
  136. wg.Add(1)
  137. cc := new(grpc.ClientConn)
  138. stream, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  139. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  140. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  141. defer wg.Done()
  142. return &mockedClientStream{
  143. err: errors.New("dummy"),
  144. }, nil
  145. })
  146. wg.Wait()
  147. assert.Nil(t, err)
  148. cs := stream.(*clientStream)
  149. cs.sendStreamEvent(test.event, status.Error(codes.DataLoss, "dummy"))
  150. <-cs.eventsDone
  151. cs.sendStreamEvent(test.event, test.err)
  152. assert.NotNil(t, cs.CloseSend())
  153. })
  154. }
  155. }
  156. func TestStreamTracingInterceptor_WithError(t *testing.T) {
  157. tests := []struct {
  158. name string
  159. err error
  160. }{
  161. {
  162. name: "normal error",
  163. err: errors.New("dummy"),
  164. },
  165. {
  166. name: "grpc error",
  167. err: status.Error(codes.DataLoss, "dummy"),
  168. },
  169. }
  170. for _, test := range tests {
  171. test := test
  172. t.Run(test.name, func(t *testing.T) {
  173. t.Parallel()
  174. var run int32
  175. var wg sync.WaitGroup
  176. wg.Add(1)
  177. cc := new(grpc.ClientConn)
  178. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  179. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  180. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  181. defer wg.Done()
  182. atomic.AddInt32(&run, 1)
  183. return new(mockedClientStream), test.err
  184. })
  185. wg.Wait()
  186. assert.NotNil(t, err)
  187. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  188. })
  189. }
  190. }
  191. func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
  192. var run int32
  193. var wg sync.WaitGroup
  194. wg.Add(1)
  195. cc := new(grpc.ClientConn)
  196. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  197. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  198. opts ...grpc.CallOption) error {
  199. defer wg.Done()
  200. atomic.AddInt32(&run, 1)
  201. return nil
  202. })
  203. wg.Wait()
  204. assert.Nil(t, err)
  205. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  206. }
  207. func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
  208. var run int32
  209. var wg sync.WaitGroup
  210. wg.Add(1)
  211. cc := new(grpc.ClientConn)
  212. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  213. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  214. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  215. defer wg.Done()
  216. atomic.AddInt32(&run, 1)
  217. return nil, nil
  218. })
  219. wg.Wait()
  220. assert.Nil(t, err)
  221. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  222. }
  223. func TestClientStream_RecvMsg(t *testing.T) {
  224. tests := []struct {
  225. name string
  226. serverStreams bool
  227. err error
  228. }{
  229. {
  230. name: "nil error",
  231. },
  232. {
  233. name: "EOF",
  234. err: io.EOF,
  235. },
  236. {
  237. name: "dummy error",
  238. err: errors.New("dummy"),
  239. },
  240. {
  241. name: "server streams",
  242. serverStreams: true,
  243. },
  244. }
  245. for _, test := range tests {
  246. test := test
  247. t.Run(test.name, func(t *testing.T) {
  248. t.Parallel()
  249. desc := new(grpc.StreamDesc)
  250. desc.ServerStreams = test.serverStreams
  251. stream := wrapClientStream(context.Background(), &mockedClientStream{
  252. md: nil,
  253. err: test.err,
  254. }, desc)
  255. assert.Equal(t, test.err, stream.RecvMsg(nil))
  256. })
  257. }
  258. }
  259. func TestClientStream_Header(t *testing.T) {
  260. tests := []struct {
  261. name string
  262. err error
  263. }{
  264. {
  265. name: "nil error",
  266. },
  267. {
  268. name: "with error",
  269. err: errors.New("dummy"),
  270. },
  271. }
  272. for _, test := range tests {
  273. test := test
  274. t.Run(test.name, func(t *testing.T) {
  275. t.Parallel()
  276. desc := new(grpc.StreamDesc)
  277. stream := wrapClientStream(context.Background(), &mockedClientStream{
  278. md: metadata.MD{},
  279. err: test.err,
  280. }, desc)
  281. _, err := stream.Header()
  282. assert.Equal(t, test.err, err)
  283. })
  284. }
  285. }
  286. func TestClientStream_SendMsg(t *testing.T) {
  287. tests := []struct {
  288. name string
  289. err error
  290. }{
  291. {
  292. name: "nil error",
  293. },
  294. {
  295. name: "with error",
  296. err: errors.New("dummy"),
  297. },
  298. }
  299. for _, test := range tests {
  300. test := test
  301. t.Run(test.name, func(t *testing.T) {
  302. t.Parallel()
  303. desc := new(grpc.StreamDesc)
  304. stream := wrapClientStream(context.Background(), &mockedClientStream{
  305. md: metadata.MD{},
  306. err: test.err,
  307. }, desc)
  308. assert.Equal(t, test.err, stream.SendMsg(nil))
  309. })
  310. }
  311. }
  312. type mockedClientStream struct {
  313. md metadata.MD
  314. err error
  315. }
  316. func (m *mockedClientStream) Header() (metadata.MD, error) {
  317. return m.md, m.err
  318. }
  319. func (m *mockedClientStream) Trailer() metadata.MD {
  320. panic("implement me")
  321. }
  322. func (m *mockedClientStream) CloseSend() error {
  323. return m.err
  324. }
  325. func (m *mockedClientStream) Context() context.Context {
  326. return context.Background()
  327. }
  328. func (m *mockedClientStream) SendMsg(v any) error {
  329. return m.err
  330. }
  331. func (m *mockedClientStream) RecvMsg(v any) error {
  332. return m.err
  333. }