tracinginterceptor_test.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "sync"
  5. "sync/atomic"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/tal-tech/go-zero/core/trace/tracespec"
  9. "google.golang.org/grpc"
  10. "google.golang.org/grpc/metadata"
  11. )
  12. func TestUnaryTracingInterceptor(t *testing.T) {
  13. interceptor := UnaryTracingInterceptor("foo")
  14. var run int32
  15. var wg sync.WaitGroup
  16. wg.Add(1)
  17. _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  18. FullMethod: "/",
  19. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  20. defer wg.Done()
  21. atomic.AddInt32(&run, 1)
  22. return nil, nil
  23. })
  24. wg.Wait()
  25. assert.Nil(t, err)
  26. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  27. }
  28. func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
  29. interceptor := UnaryTracingInterceptor("foo")
  30. var wg sync.WaitGroup
  31. wg.Add(1)
  32. var md metadata.MD
  33. ctx := metadata.NewIncomingContext(context.Background(), md)
  34. _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
  35. FullMethod: "/",
  36. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  37. defer wg.Done()
  38. assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).TraceId()) > 0)
  39. assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).SpanId()) > 0)
  40. return nil, nil
  41. })
  42. wg.Wait()
  43. assert.Nil(t, err)
  44. }
  45. func TestStreamTracingInterceptor(t *testing.T) {
  46. interceptor := StreamTracingInterceptor("foo")
  47. var run int32
  48. var wg sync.WaitGroup
  49. wg.Add(1)
  50. err := interceptor(nil, new(mockedServerStream), nil,
  51. func(srv interface{}, stream grpc.ServerStream) error {
  52. defer wg.Done()
  53. atomic.AddInt32(&run, 1)
  54. return nil
  55. })
  56. wg.Wait()
  57. assert.Nil(t, err)
  58. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  59. }
  60. func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
  61. interceptor := StreamTracingInterceptor("foo")
  62. var run int32
  63. var wg sync.WaitGroup
  64. wg.Add(1)
  65. var md metadata.MD
  66. ctx := metadata.NewIncomingContext(context.Background(), md)
  67. stream := mockedServerStream{ctx: ctx}
  68. err := interceptor(nil, &stream, &grpc.StreamServerInfo{
  69. FullMethod: "/foo",
  70. }, func(srv interface{}, stream grpc.ServerStream) error {
  71. defer wg.Done()
  72. atomic.AddInt32(&run, 1)
  73. return nil
  74. })
  75. wg.Wait()
  76. assert.Nil(t, err)
  77. assert.Equal(t, int32(1), atomic.LoadInt32(&run))
  78. }
  79. type mockedServerStream struct {
  80. ctx context.Context
  81. }
  82. func (m *mockedServerStream) SetHeader(md metadata.MD) error {
  83. panic("implement me")
  84. }
  85. func (m *mockedServerStream) SendHeader(md metadata.MD) error {
  86. panic("implement me")
  87. }
  88. func (m *mockedServerStream) SetTrailer(md metadata.MD) {
  89. panic("implement me")
  90. }
  91. func (m *mockedServerStream) Context() context.Context {
  92. if m.ctx == nil {
  93. return context.Background()
  94. }
  95. return m.ctx
  96. }
  97. func (m *mockedServerStream) SendMsg(v interface{}) error {
  98. panic("implement me")
  99. }
  100. func (m *mockedServerStream) RecvMsg(v interface{}) error {
  101. panic("implement me")
  102. }