瀏覽代碼

add more tests (#1113)

Kevin Wan 3 年之前
父節點
當前提交
9f5bfa0088

+ 1 - 1
zrpc/internal/clientinterceptors/tracinginterceptor.go

@@ -67,10 +67,10 @@ func StreamTracingInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *gr
 			s, ok := status.FromError(err)
 			s, ok := status.FromError(err)
 			if ok {
 			if ok {
 				span.SetStatus(codes.Error, s.Message())
 				span.SetStatus(codes.Error, s.Message())
+				span.SetAttributes(ztrace.StatusCodeAttr(s.Code()))
 			} else {
 			} else {
 				span.SetStatus(codes.Error, err.Error())
 				span.SetStatus(codes.Error, err.Error())
 			}
 			}
-			span.SetAttributes(ztrace.StatusCodeAttr(s.Code()))
 		} else {
 		} else {
 			span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK))
 			span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK))
 		}
 		}

+ 125 - 7
zrpc/internal/clientinterceptors/tracinginterceptor_test.go

@@ -10,6 +10,9 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/core/trace"
 	"github.com/tal-tech/go-zero/core/trace"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
 )
 )
 
 
 func TestOpenTracingInterceptor(t *testing.T) {
 func TestOpenTracingInterceptor(t *testing.T) {
@@ -80,21 +83,107 @@ func TestStreamTracingInterceptor(t *testing.T) {
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 }
 }
 
 
-func TestStreamTracingInterceptor_WithError(t *testing.T) {
-	var run int32
+func TestStreamTracingInterceptor_FinishWithNormalError(t *testing.T) {
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	wg.Add(1)
 	wg.Add(1)
 	cc := new(grpc.ClientConn)
 	cc := new(grpc.ClientConn)
-	_, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
+	ctx, cancel := context.WithCancel(context.Background())
+	stream, err := StreamTracingInterceptor(ctx, nil, cc, "/foo",
 		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
 		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
 			opts ...grpc.CallOption) (grpc.ClientStream, error) {
 			opts ...grpc.CallOption) (grpc.ClientStream, error) {
 			defer wg.Done()
 			defer wg.Done()
-			atomic.AddInt32(&run, 1)
-			return nil, errors.New("dummy")
+			return nil, nil
 		})
 		})
 	wg.Wait()
 	wg.Wait()
-	assert.NotNil(t, err)
-	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
+	assert.Nil(t, err)
+
+	cancel()
+	cs := stream.(*clientStream)
+	<-cs.eventsDone
+}
+
+func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
+	tests := []struct {
+		name  string
+		event streamEventType
+		err   error
+	}{
+		{
+			name:  "receive event",
+			event: receiveEndEvent,
+			err:   status.Error(codes.DataLoss, "dummy"),
+		},
+		{
+			name:  "error event",
+			event: errorEvent,
+			err:   status.Error(codes.DataLoss, "dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+			cc := new(grpc.ClientConn)
+			stream, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
+				func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
+					opts ...grpc.CallOption) (grpc.ClientStream, error) {
+					defer wg.Done()
+					return &mockedClientStream{
+						err: errors.New("dummy"),
+					}, nil
+				})
+			wg.Wait()
+			assert.Nil(t, err)
+
+			cs := stream.(*clientStream)
+			cs.sendStreamEvent(test.event, status.Error(codes.DataLoss, "dummy"))
+			<-cs.eventsDone
+			cs.sendStreamEvent(test.event, test.err)
+			assert.NotNil(t, cs.CloseSend())
+		})
+	}
+}
+
+func TestStreamTracingInterceptor_WithError(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "normal error",
+			err:  errors.New("dummy"),
+		},
+		{
+			name: "grpc error",
+			err:  status.Error(codes.DataLoss, "dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			var run int32
+			var wg sync.WaitGroup
+			wg.Add(1)
+			cc := new(grpc.ClientConn)
+			_, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
+				func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
+					opts ...grpc.CallOption) (grpc.ClientStream, error) {
+					defer wg.Done()
+					atomic.AddInt32(&run, 1)
+					return new(mockedClientStream), test.err
+				})
+			wg.Wait()
+			assert.NotNil(t, err)
+			assert.Equal(t, int32(1), atomic.LoadInt32(&run))
+		})
+	}
 }
 }
 
 
 func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
 func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
@@ -130,3 +219,32 @@ func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 }
 }
+
+type mockedClientStream struct {
+	md  metadata.MD
+	err error
+}
+
+func (m *mockedClientStream) Header() (metadata.MD, error) {
+	return m.md, m.err
+}
+
+func (m *mockedClientStream) Trailer() metadata.MD {
+	panic("implement me")
+}
+
+func (m *mockedClientStream) CloseSend() error {
+	return m.err
+}
+
+func (m *mockedClientStream) Context() context.Context {
+	panic("implement me")
+}
+
+func (m *mockedClientStream) SendMsg(v interface{}) error {
+	panic("implement me")
+}
+
+func (m *mockedClientStream) RecvMsg(v interface{}) error {
+	panic("implement me")
+}