浏览代码

add more tests (#1114)

Kevin Wan 3 年之前
父节点
当前提交
3cdfcb05f1

+ 18 - 18
zrpc/internal/clientinterceptors/tracinginterceptor.go

@@ -100,6 +100,24 @@ type (
 	}
 )
 
+func (w *clientStream) CloseSend() error {
+	err := w.ClientStream.CloseSend()
+	if err != nil {
+		w.sendStreamEvent(errorEvent, err)
+	}
+
+	return err
+}
+
+func (w *clientStream) Header() (metadata.MD, error) {
+	md, err := w.ClientStream.Header()
+	if err != nil {
+		w.sendStreamEvent(errorEvent, err)
+	}
+
+	return md, err
+}
+
 func (w *clientStream) RecvMsg(m interface{}) error {
 	err := w.ClientStream.RecvMsg(m)
 	if err == nil && !w.desc.ServerStreams {
@@ -127,24 +145,6 @@ func (w *clientStream) SendMsg(m interface{}) error {
 	return err
 }
 
-func (w *clientStream) Header() (metadata.MD, error) {
-	md, err := w.ClientStream.Header()
-	if err != nil {
-		w.sendStreamEvent(errorEvent, err)
-	}
-
-	return md, err
-}
-
-func (w *clientStream) CloseSend() error {
-	err := w.ClientStream.CloseSend()
-	if err != nil {
-		w.sendStreamEvent(errorEvent, err)
-	}
-
-	return err
-}
-
 func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
 	select {
 	case <-w.eventsDone:

+ 101 - 4
zrpc/internal/clientinterceptors/tracinginterceptor_test.go

@@ -3,6 +3,7 @@ package clientinterceptors
 import (
 	"context"
 	"errors"
+	"io"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -24,7 +25,8 @@ func TestOpenTracingInterceptor(t *testing.T) {
 	})
 
 	cc := new(grpc.ClientConn)
-	err := UnaryTracingInterceptor(context.Background(), "/ListUser", nil, nil, cc,
+	ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{})
+	err := UnaryTracingInterceptor(ctx, "/ListUser", nil, nil, cc,
 		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
 			opts ...grpc.CallOption) error {
 			return nil
@@ -220,6 +222,101 @@ func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 }
 
+func TestClientStream_RecvMsg(t *testing.T) {
+	tests := []struct {
+		name          string
+		serverStreams bool
+		err           error
+	}{
+		{
+			name: "nil error",
+		},
+		{
+			name: "EOF",
+			err:  io.EOF,
+		},
+		{
+			name: "dummy error",
+			err:  errors.New("dummy"),
+		},
+		{
+			name:          "server streams",
+			serverStreams: true,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+			desc := new(grpc.StreamDesc)
+			desc.ServerStreams = test.serverStreams
+			stream := wrapClientStream(context.Background(), &mockedClientStream{
+				md:  nil,
+				err: test.err,
+			}, desc)
+			assert.Equal(t, test.err, stream.RecvMsg(nil))
+		})
+	}
+}
+
+func TestClientStream_Header(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "nil error",
+		},
+		{
+			name: "with error",
+			err:  errors.New("dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+			desc := new(grpc.StreamDesc)
+			stream := wrapClientStream(context.Background(), &mockedClientStream{
+				md:  metadata.MD{},
+				err: test.err,
+			}, desc)
+			_, err := stream.Header()
+			assert.Equal(t, test.err, err)
+		})
+	}
+}
+
+func TestClientStream_SendMsg(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "nil error",
+		},
+		{
+			name: "with error",
+			err:  errors.New("dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+			desc := new(grpc.StreamDesc)
+			stream := wrapClientStream(context.Background(), &mockedClientStream{
+				md:  metadata.MD{},
+				err: test.err,
+			}, desc)
+			assert.Equal(t, test.err, stream.SendMsg(nil))
+		})
+	}
+}
+
 type mockedClientStream struct {
 	md  metadata.MD
 	err error
@@ -238,13 +335,13 @@ func (m *mockedClientStream) CloseSend() error {
 }
 
 func (m *mockedClientStream) Context() context.Context {
-	panic("implement me")
+	return context.Background()
 }
 
 func (m *mockedClientStream) SendMsg(v interface{}) error {
-	panic("implement me")
+	return m.err
 }
 
 func (m *mockedClientStream) RecvMsg(v interface{}) error {
-	panic("implement me")
+	return m.err
 }