tracinginterceptor_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. package clientinterceptors
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "sync/atomic"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. ztrace "github.com/wuntsong-org/go-zero-plus/core/trace"
  11. "github.com/wuntsong-org/go-zero-plus/core/trace/tracetest"
  12. "go.opentelemetry.io/otel/attribute"
  13. tcodes "go.opentelemetry.io/otel/codes"
  14. "go.opentelemetry.io/otel/sdk/trace"
  15. semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
  16. "google.golang.org/grpc"
  17. "google.golang.org/grpc/codes"
  18. "google.golang.org/grpc/metadata"
  19. "google.golang.org/grpc/status"
  20. )
  21. func TestOpenTracingInterceptor(t *testing.T) {
  22. ztrace.StartAgent(ztrace.Config{
  23. Name: "go-zero-test",
  24. Endpoint: "http://localhost:14268/api/traces",
  25. Batcher: "jaeger",
  26. Sampler: 1.0,
  27. })
  28. defer ztrace.StopAgent()
  29. cc := new(grpc.ClientConn)
  30. ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{})
  31. err := UnaryTracingInterceptor(ctx, "/ListUser", nil, nil, cc,
  32. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  33. opts ...grpc.CallOption) error {
  34. return nil
  35. })
  36. assert.Nil(t, err)
  37. }
  38. func TestUnaryTracingInterceptor(t *testing.T) {
  39. t.Run("normal", func(t *testing.T) {
  40. var run int32
  41. cc := new(grpc.ClientConn)
  42. me := tracetest.NewInMemoryExporter(t)
  43. err := UnaryTracingInterceptor(context.Background(), "/proto.Hello/Echo",
  44. nil, nil, cc,
  45. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  46. opts ...grpc.CallOption) error {
  47. atomic.AddInt32(&run, 1)
  48. return nil
  49. })
  50. assert.Nil(t, err)
  51. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  52. assert.Equal(t, 1, len(me.GetSpans()))
  53. span := me.GetSpans()[0].Snapshot()
  54. assert.Equal(t, 2, len(span.Events()))
  55. assert.ElementsMatch(t, []attribute.KeyValue{
  56. ztrace.RPCSystemGRPC,
  57. semconv.RPCServiceKey.String("proto.Hello"),
  58. semconv.RPCMethodKey.String("Echo"),
  59. ztrace.StatusCodeAttr(codes.OK),
  60. }, span.Attributes())
  61. })
  62. t.Run("grpc error status", func(t *testing.T) {
  63. me := tracetest.NewInMemoryExporter(t)
  64. cc := new(grpc.ClientConn)
  65. err := UnaryTracingInterceptor(context.Background(), "/proto.Hello/Echo",
  66. nil, nil, cc,
  67. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  68. opts ...grpc.CallOption) error {
  69. return status.Error(codes.Unknown, "test")
  70. })
  71. assert.Error(t, err)
  72. assert.Equal(t, 1, len(me.GetSpans()))
  73. span := me.GetSpans()[0].Snapshot()
  74. assert.Equal(t, trace.Status{
  75. Code: tcodes.Error,
  76. Description: "test",
  77. }, span.Status())
  78. assert.Equal(t, 2, len(span.Events()))
  79. assert.ElementsMatch(t, []attribute.KeyValue{
  80. ztrace.RPCSystemGRPC,
  81. semconv.RPCServiceKey.String("proto.Hello"),
  82. semconv.RPCMethodKey.String("Echo"),
  83. ztrace.StatusCodeAttr(codes.Unknown),
  84. }, span.Attributes())
  85. })
  86. t.Run("non grpc status error", func(t *testing.T) {
  87. me := tracetest.NewInMemoryExporter(t)
  88. cc := new(grpc.ClientConn)
  89. err := UnaryTracingInterceptor(context.Background(), "/proto.Hello/Echo",
  90. nil, nil, cc,
  91. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  92. opts ...grpc.CallOption) error {
  93. return errors.New("test")
  94. })
  95. assert.Error(t, err)
  96. assert.Equal(t, 1, len(me.GetSpans()))
  97. span := me.GetSpans()[0].Snapshot()
  98. assert.Equal(t, trace.Status{
  99. Code: tcodes.Error,
  100. Description: "test",
  101. }, span.Status())
  102. assert.Equal(t, 2, len(span.Events()))
  103. assert.ElementsMatch(t, []attribute.KeyValue{
  104. ztrace.RPCSystemGRPC,
  105. semconv.RPCServiceKey.String("proto.Hello"),
  106. semconv.RPCMethodKey.String("Echo"),
  107. }, span.Attributes())
  108. })
  109. }
  110. func TestUnaryTracingInterceptor_WithError(t *testing.T) {
  111. var run int32
  112. var wg sync.WaitGroup
  113. wg.Add(1)
  114. cc := new(grpc.ClientConn)
  115. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  116. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  117. opts ...grpc.CallOption) error {
  118. defer wg.Done()
  119. atomic.AddInt32(&run, 1)
  120. return errors.New("dummy")
  121. })
  122. wg.Wait()
  123. assert.NotNil(t, err)
  124. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  125. }
  126. func TestUnaryTracingInterceptor_WithStatusError(t *testing.T) {
  127. var run int32
  128. var wg sync.WaitGroup
  129. wg.Add(1)
  130. cc := new(grpc.ClientConn)
  131. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  132. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  133. opts ...grpc.CallOption) error {
  134. defer wg.Done()
  135. atomic.AddInt32(&run, 1)
  136. return status.Error(codes.DataLoss, "dummy")
  137. })
  138. wg.Wait()
  139. assert.NotNil(t, err)
  140. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  141. }
  142. func TestStreamTracingInterceptor(t *testing.T) {
  143. var run int32
  144. var wg sync.WaitGroup
  145. wg.Add(1)
  146. cc := new(grpc.ClientConn)
  147. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  148. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  149. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  150. defer wg.Done()
  151. atomic.AddInt32(&run, 1)
  152. return nil, nil
  153. })
  154. wg.Wait()
  155. assert.Nil(t, err)
  156. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  157. }
  158. func TestStreamTracingInterceptor_FinishWithNormalError(t *testing.T) {
  159. var wg sync.WaitGroup
  160. wg.Add(1)
  161. cc := new(grpc.ClientConn)
  162. ctx, cancel := context.WithCancel(context.Background())
  163. stream, err := StreamTracingInterceptor(ctx, nil, cc, "/foo",
  164. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  165. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  166. defer wg.Done()
  167. return nil, nil
  168. })
  169. wg.Wait()
  170. assert.Nil(t, err)
  171. cancel()
  172. cs := stream.(*clientStream)
  173. <-cs.eventsDone
  174. }
  175. func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
  176. tests := []struct {
  177. name string
  178. event streamEventType
  179. err error
  180. }{
  181. {
  182. name: "receive event",
  183. event: receiveEndEvent,
  184. err: status.Error(codes.DataLoss, "dummy"),
  185. },
  186. {
  187. name: "error event",
  188. event: errorEvent,
  189. err: status.Error(codes.DataLoss, "dummy"),
  190. },
  191. }
  192. for _, test := range tests {
  193. test := test
  194. t.Run(test.name, func(t *testing.T) {
  195. t.Parallel()
  196. var wg sync.WaitGroup
  197. wg.Add(1)
  198. cc := new(grpc.ClientConn)
  199. stream, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  200. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  201. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  202. defer wg.Done()
  203. return &mockedClientStream{
  204. err: errors.New("dummy"),
  205. }, nil
  206. })
  207. wg.Wait()
  208. assert.Nil(t, err)
  209. cs := stream.(*clientStream)
  210. cs.sendStreamEvent(test.event, status.Error(codes.DataLoss, "dummy"))
  211. <-cs.eventsDone
  212. cs.sendStreamEvent(test.event, test.err)
  213. assert.NotNil(t, cs.CloseSend())
  214. })
  215. }
  216. }
  217. func TestStreamTracingInterceptor_WithError(t *testing.T) {
  218. tests := []struct {
  219. name string
  220. err error
  221. }{
  222. {
  223. name: "normal error",
  224. err: errors.New("dummy"),
  225. },
  226. {
  227. name: "grpc error",
  228. err: status.Error(codes.DataLoss, "dummy"),
  229. },
  230. }
  231. for _, test := range tests {
  232. test := test
  233. t.Run(test.name, func(t *testing.T) {
  234. t.Parallel()
  235. var run int32
  236. var wg sync.WaitGroup
  237. wg.Add(1)
  238. cc := new(grpc.ClientConn)
  239. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  240. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  241. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  242. defer wg.Done()
  243. atomic.AddInt32(&run, 1)
  244. return new(mockedClientStream), test.err
  245. })
  246. wg.Wait()
  247. assert.NotNil(t, err)
  248. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  249. })
  250. }
  251. }
  252. func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
  253. var run int32
  254. var wg sync.WaitGroup
  255. wg.Add(1)
  256. cc := new(grpc.ClientConn)
  257. err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
  258. func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
  259. opts ...grpc.CallOption) error {
  260. defer wg.Done()
  261. atomic.AddInt32(&run, 1)
  262. return nil
  263. })
  264. wg.Wait()
  265. assert.Nil(t, err)
  266. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  267. }
  268. func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
  269. var run int32
  270. var wg sync.WaitGroup
  271. wg.Add(1)
  272. cc := new(grpc.ClientConn)
  273. _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
  274. func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  275. opts ...grpc.CallOption) (grpc.ClientStream, error) {
  276. defer wg.Done()
  277. atomic.AddInt32(&run, 1)
  278. return nil, nil
  279. })
  280. wg.Wait()
  281. assert.Nil(t, err)
  282. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  283. }
  284. func TestClientStream_RecvMsg(t *testing.T) {
  285. tests := []struct {
  286. name string
  287. serverStreams bool
  288. err error
  289. }{
  290. {
  291. name: "nil error",
  292. },
  293. {
  294. name: "EOF",
  295. err: io.EOF,
  296. },
  297. {
  298. name: "dummy error",
  299. err: errors.New("dummy"),
  300. },
  301. {
  302. name: "server streams",
  303. serverStreams: true,
  304. },
  305. }
  306. for _, test := range tests {
  307. test := test
  308. t.Run(test.name, func(t *testing.T) {
  309. t.Parallel()
  310. desc := new(grpc.StreamDesc)
  311. desc.ServerStreams = test.serverStreams
  312. stream := wrapClientStream(context.Background(), &mockedClientStream{
  313. md: nil,
  314. err: test.err,
  315. }, desc)
  316. assert.Equal(t, test.err, stream.RecvMsg(nil))
  317. })
  318. }
  319. }
  320. func TestClientStream_Header(t *testing.T) {
  321. tests := []struct {
  322. name string
  323. err error
  324. }{
  325. {
  326. name: "nil error",
  327. },
  328. {
  329. name: "with error",
  330. err: errors.New("dummy"),
  331. },
  332. }
  333. for _, test := range tests {
  334. test := test
  335. t.Run(test.name, func(t *testing.T) {
  336. t.Parallel()
  337. desc := new(grpc.StreamDesc)
  338. stream := wrapClientStream(context.Background(), &mockedClientStream{
  339. md: metadata.MD{},
  340. err: test.err,
  341. }, desc)
  342. _, err := stream.Header()
  343. assert.Equal(t, test.err, err)
  344. })
  345. }
  346. }
  347. func TestClientStream_SendMsg(t *testing.T) {
  348. tests := []struct {
  349. name string
  350. err error
  351. }{
  352. {
  353. name: "nil error",
  354. },
  355. {
  356. name: "with error",
  357. err: errors.New("dummy"),
  358. },
  359. }
  360. for _, test := range tests {
  361. test := test
  362. t.Run(test.name, func(t *testing.T) {
  363. t.Parallel()
  364. desc := new(grpc.StreamDesc)
  365. stream := wrapClientStream(context.Background(), &mockedClientStream{
  366. md: metadata.MD{},
  367. err: test.err,
  368. }, desc)
  369. assert.Equal(t, test.err, stream.SendMsg(nil))
  370. })
  371. }
  372. }
  373. type mockedClientStream struct {
  374. md metadata.MD
  375. err error
  376. }
  377. func (m *mockedClientStream) Header() (metadata.MD, error) {
  378. return m.md, m.err
  379. }
  380. func (m *mockedClientStream) Trailer() metadata.MD {
  381. panic("implement me")
  382. }
  383. func (m *mockedClientStream) CloseSend() error {
  384. return m.err
  385. }
  386. func (m *mockedClientStream) Context() context.Context {
  387. return context.Background()
  388. }
  389. func (m *mockedClientStream) SendMsg(v any) error {
  390. return m.err
  391. }
  392. func (m *mockedClientStream) RecvMsg(v any) error {
  393. return m.err
  394. }