浏览代码

add more tests (#1115)

* add more tests

* fix lint errors
Kevin Wan 3 年之前
父节点
当前提交
4c6234f108

+ 52 - 0
zrpc/internal/serverinterceptors/statinterceptor_test.go

@@ -2,11 +2,15 @@ package serverinterceptors
 
 import (
 	"context"
+	"net"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/lang"
 	"github.com/tal-tech/go-zero/core/stat"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/peer"
 )
 
 func TestUnaryStatInterceptor(t *testing.T) {
@@ -30,3 +34,51 @@ func TestUnaryStatInterceptor_crash(t *testing.T) {
 	})
 	assert.NotNil(t, err)
 }
+
+func TestLogDuration(t *testing.T) {
+	addrs, err := net.InterfaceAddrs()
+	assert.Nil(t, err)
+	assert.True(t, len(addrs) > 0)
+
+	tests := []struct {
+		name     string
+		ctx      context.Context
+		req      interface{}
+		duration time.Duration
+	}{
+		{
+			name: "normal",
+			ctx:  context.Background(),
+			req:  "foo",
+		},
+		{
+			name: "bad req",
+			ctx:  context.Background(),
+			req:  make(chan lang.PlaceholderType), // not marshalable
+		},
+		{
+			name:     "timeout",
+			ctx:      context.Background(),
+			req:      "foo",
+			duration: time.Second,
+		},
+		{
+			name: "timeout",
+			ctx: peer.NewContext(context.Background(), &peer.Peer{
+				Addr: addrs[0],
+			}),
+			req: "foo",
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			assert.NotPanics(t, func() {
+				logDuration(test.ctx, "foo", test.req, test.duration)
+			})
+		})
+	}
+}

+ 175 - 2
zrpc/internal/serverinterceptors/tracinginterceptor_test.go

@@ -2,6 +2,8 @@ package serverinterceptors
 
 import (
 	"context"
+	"errors"
+	"io"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -9,7 +11,9 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/core/trace"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
 )
 
 func TestUnaryOpenTracingInterceptor_Disable(t *testing.T) {
@@ -52,6 +56,42 @@ func TestUnaryTracingInterceptor(t *testing.T) {
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 }
 
+func TestUnaryTracingInterceptor_WithError(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "normal error",
+			err:  errors.New("dummy"),
+		},
+		{
+			name: "grpc error",
+			err:  status.Error(codes.DataLoss, "dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+			var md metadata.MD
+			ctx := metadata.NewIncomingContext(context.Background(), md)
+			_, err := UnaryTracingInterceptor(ctx, nil, &grpc.UnaryServerInfo{
+				FullMethod: "/",
+			}, func(ctx context.Context, req interface{}) (interface{}, error) {
+				defer wg.Done()
+				return nil, test.err
+			})
+			wg.Wait()
+			assert.Equal(t, test.err, err)
+		})
+	}
+}
+
 func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
 	var run int32
 	var wg sync.WaitGroup
@@ -71,8 +111,141 @@ func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 }
 
+func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "receive event",
+			err:  status.Error(codes.DataLoss, "dummy"),
+		},
+		{
+			name: "error event",
+			err:  status.Error(codes.DataLoss, "dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+			var md metadata.MD
+			ctx := metadata.NewIncomingContext(context.Background(), md)
+			stream := mockedServerStream{ctx: ctx}
+			err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
+				FullMethod: "/foo",
+			}, func(srv interface{}, stream grpc.ServerStream) error {
+				defer wg.Done()
+				return test.err
+			})
+			wg.Wait()
+			assert.Equal(t, test.err, err)
+		})
+	}
+}
+
+func TestStreamTracingInterceptor_WithError(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "normal error",
+			err:  errors.New("dummy"),
+		},
+		{
+			name: "grpc error",
+			err:  status.Error(codes.DataLoss, "dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+			var md metadata.MD
+			ctx := metadata.NewIncomingContext(context.Background(), md)
+			stream := mockedServerStream{ctx: ctx}
+			err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
+				FullMethod: "/foo",
+			}, func(srv interface{}, stream grpc.ServerStream) error {
+				defer wg.Done()
+				return test.err
+			})
+			wg.Wait()
+			assert.Equal(t, test.err, err)
+		})
+	}
+}
+
+func TestClientStream_RecvMsg(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "nil error",
+		},
+		{
+			name: "EOF",
+			err:  io.EOF,
+		},
+		{
+			name: "dummy error",
+			err:  errors.New("dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+			stream := wrapServerStream(context.Background(), &mockedServerStream{
+				ctx: context.Background(),
+				err: test.err,
+			})
+			assert.Equal(t, test.err, stream.RecvMsg(nil))
+		})
+	}
+}
+
+func TestServerStream_SendMsg(t *testing.T) {
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "nil error",
+		},
+		{
+			name: "with error",
+			err:  errors.New("dummy"),
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+			stream := wrapServerStream(context.Background(), &mockedServerStream{
+				ctx: context.Background(),
+				err: test.err,
+			})
+			assert.Equal(t, test.err, stream.SendMsg(nil))
+		})
+	}
+}
+
 type mockedServerStream struct {
 	ctx context.Context
+	err error
 }
 
 func (m *mockedServerStream) SetHeader(md metadata.MD) error {
@@ -96,9 +269,9 @@ func (m *mockedServerStream) Context() context.Context {
 }
 
 func (m *mockedServerStream) SendMsg(v interface{}) error {
-	panic("implement me")
+	return m.err
 }
 
 func (m *mockedServerStream) RecvMsg(v interface{}) error {
-	panic("implement me")
+	return m.err
 }