Forráskód Böngészése

chore: add more tests (#3010)

Kevin Wan 2 éve
szülő
commit
fbf129d535

+ 1 - 1
zrpc/internal/balancer/p2c/p2c.go

@@ -72,7 +72,7 @@ type p2cPicker struct {
 	lock  sync.Mutex
 }
 
-func (p *p2cPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
+func (p *p2cPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
 	p.lock.Lock()
 	defer p.lock.Unlock()
 

+ 9 - 0
zrpc/internal/balancer/p2c/p2c_test.go

@@ -123,6 +123,15 @@ func TestP2cPicker_Pick(t *testing.T) {
 	}
 }
 
+func TestPickerWithEmptyConns(t *testing.T) {
+	var picker p2cPicker
+	_, err := picker.Pick(balancer.PickInfo{
+		FullMethodName: "/",
+		Ctx:            context.Background(),
+	})
+	assert.ErrorIs(t, err, balancer.ErrNoSubConnAvailable)
+}
+
 type mockClientConn struct {
 	// add random string member to avoid map key equality.
 	id string

+ 69 - 0
zrpc/internal/clientinterceptors/durationinterceptor_test.go

@@ -24,13 +24,82 @@ func TestDurationInterceptor(t *testing.T) {
 			err:  errors.New("mock"),
 		},
 	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			cc := new(grpc.ClientConn)
+			err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc,
+				func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
+					opts ...grpc.CallOption) error {
+					return test.err
+				})
+			assert.Equal(t, test.err, err)
+		})
+	}
+
+	DontLogContentForMethod("/foo")
+	t.Cleanup(func() {
+		notLoggingContentMethods.Delete("/foo")
+	})
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			cc := new(grpc.ClientConn)
+			err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc,
+				func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
+					opts ...grpc.CallOption) error {
+					return test.err
+				})
+			assert.Equal(t, test.err, err)
+		})
+	}
+}
+
+func TestDurationInterceptorWithSlowThreshold(t *testing.T) {
+	SetSlowThreshold(time.Microsecond)
+	t.Cleanup(func() {
+		SetSlowThreshold(defaultSlowThreshold)
+	})
+
+	tests := []struct {
+		name string
+		err  error
+	}{
+		{
+			name: "nil",
+			err:  nil,
+		},
+		{
+			name: "with error",
+			err:  errors.New("mock"),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			cc := new(grpc.ClientConn)
+			err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc,
+				func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
+					opts ...grpc.CallOption) error {
+					time.Sleep(time.Millisecond * 10)
+					return test.err
+				})
+			assert.Equal(t, test.err, err)
+		})
+	}
+
 	DontLogContentForMethod("/foo")
+	t.Cleanup(func() {
+		notLoggingContentMethods.Delete("/foo")
+	})
+
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			cc := new(grpc.ClientConn)
 			err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc,
 				func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
 					opts ...grpc.CallOption) error {
+					time.Sleep(time.Millisecond * 10)
 					return test.err
 				})
 			assert.Equal(t, test.err, err)

+ 17 - 0
zrpc/internal/clientinterceptors/tracinginterceptor_test.go

@@ -69,6 +69,23 @@ func TestUnaryTracingInterceptor_WithError(t *testing.T) {
 	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
 }
 
+func TestUnaryTracingInterceptor_WithStatusError(t *testing.T) {
+	var run int32
+	var wg sync.WaitGroup
+	wg.Add(1)
+	cc := new(grpc.ClientConn)
+	err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
+		func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
+			opts ...grpc.CallOption) error {
+			defer wg.Done()
+			atomic.AddInt32(&run, 1)
+			return status.Error(codes.DataLoss, "dummy")
+		})
+	wg.Wait()
+	assert.NotNil(t, err)
+	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
+}
+
 func TestStreamTracingInterceptor(t *testing.T) {
 	var run int32
 	var wg sync.WaitGroup

+ 0 - 5
zrpc/resolver/internal/kube/targetparser.go

@@ -1,7 +1,6 @@
 package kube
 
 import (
-	"fmt"
 	"strconv"
 	"strings"
 
@@ -34,10 +33,6 @@ func ParseTarget(target resolver.Target) (Service, error) {
 	endpoints := targets.GetEndpoints(target)
 	if strings.Contains(endpoints, colon) {
 		segs := strings.SplitN(endpoints, colon, 2)
-		if len(segs) < 2 {
-			return emptyService, fmt.Errorf("bad endpoint: %s", endpoints)
-		}
-
 		service.Name = segs[0]
 		port, err := strconv.Atoi(segs[1])
 		if err != nil {

+ 13 - 7
zrpc/resolver/internal/kube/targetparser_test.go

@@ -51,18 +51,24 @@ func TestParseTarget(t *testing.T) {
 			input:  "k8s://ns1/my-svc:800a",
 			hasErr: true,
 		},
+		{
+			name:   "bad endpoint",
+			input:  "k8s://ns1:800/:",
+			hasErr: true,
+		},
 	}
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			uri, err := url.Parse(test.input)
-			assert.Nil(t, err)
-			svc, err := ParseTarget(resolver.Target{URL: *uri})
-			if test.hasErr {
-				assert.NotNil(t, err)
-			} else {
-				assert.Nil(t, err)
-				assert.Equal(t, test.expect, svc)
+			if assert.NoError(t, err) {
+				svc, err := ParseTarget(resolver.Target{URL: *uri})
+				if test.hasErr {
+					assert.NotNil(t, err)
+				} else {
+					assert.Nil(t, err)
+					assert.Equal(t, test.expect, svc)
+				}
 			}
 		})
 	}