kevin 4 gadi atpakaļ
vecāks
revīzija
015e284515

+ 81 - 5
rpcx/internal/chainclientinterceptors_test.go

@@ -2,7 +2,6 @@ package internal
 
 import (
 	"context"
-	"sync/atomic"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -20,28 +19,105 @@ func TestWithUnaryClientInterceptors(t *testing.T) {
 }
 
 func TestChainStreamClientInterceptors_zero(t *testing.T) {
+	var vals []int
 	interceptors := chainStreamClientInterceptors()
 	_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
 		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
 			opts ...grpc.CallOption) (grpc.ClientStream, error) {
+			vals = append(vals, 1)
 			return nil, nil
 		})
 	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1}, vals)
 }
 
 func TestChainStreamClientInterceptors_one(t *testing.T) {
-	var called int32
+	var vals []int
 	interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
 		cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
 		grpc.ClientStream, error) {
-		atomic.AddInt32(&called, 1)
-		return nil, nil
+		vals = append(vals, 1)
+		return streamer(ctx, desc, cc, method, opts...)
 	})
 	_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
 		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
 			opts ...grpc.CallOption) (grpc.ClientStream, error) {
+			vals = append(vals, 2)
 			return nil, nil
 		})
 	assert.Nil(t, err)
-	assert.Equal(t, int32(1), atomic.LoadInt32(&called))
+	assert.ElementsMatch(t, []int{1, 2}, vals)
+}
+
+func TestChainStreamClientInterceptors_more(t *testing.T) {
+	var vals []int
+	interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
+		cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
+		grpc.ClientStream, error) {
+		vals = append(vals, 1)
+		return streamer(ctx, desc, cc, method, opts...)
+	}, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
+		streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+		vals = append(vals, 2)
+		return streamer(ctx, desc, cc, method, opts...)
+	})
+	_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
+		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
+			opts ...grpc.CallOption) (grpc.ClientStream, error) {
+			vals = append(vals, 3)
+			return nil, nil
+		})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
+}
+
+func TestWithUnaryClientInterceptors_zero(t *testing.T) {
+	var vals []int
+	interceptors := chainUnaryClientInterceptors()
+	err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
+		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
+			opts ...grpc.CallOption) error {
+			vals = append(vals, 1)
+			return nil
+		})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1}, vals)
+}
+
+func TestWithUnaryClientInterceptors_one(t *testing.T) {
+	var vals []int
+	interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
+		reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+		vals = append(vals, 1)
+		return invoker(ctx, method, req, reply, cc, opts...)
+	})
+	err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
+		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
+			opts ...grpc.CallOption) error {
+			vals = append(vals, 2)
+			return nil
+		})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1, 2}, vals)
+}
+
+func TestWithUnaryClientInterceptors_more(t *testing.T) {
+	var vals []int
+	interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
+		reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+		vals = append(vals, 1)
+		return invoker(ctx, method, req, reply, cc, opts...)
+	}, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
+		invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+		vals = append(vals, 2)
+		return invoker(ctx, method, req, reply, cc, opts...)
+	})
+	err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
+		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
+			opts ...grpc.CallOption) error {
+			vals = append(vals, 3)
+			return nil
+		})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
 }

+ 111 - 0
rpcx/internal/chainserverinterceptors_test.go

@@ -0,0 +1,111 @@
+package internal
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"google.golang.org/grpc"
+)
+
+func TestWithStreamServerInterceptors(t *testing.T) {
+	opts := WithStreamServerInterceptors()
+	assert.NotNil(t, opts)
+}
+
+func TestWithUnaryServerInterceptors(t *testing.T) {
+	opts := WithUnaryServerInterceptors()
+	assert.NotNil(t, opts)
+}
+
+func TestChainStreamServerInterceptors_zero(t *testing.T) {
+	var vals []int
+	interceptors := chainStreamServerInterceptors()
+	err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
+		vals = append(vals, 1)
+		return nil
+	})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1}, vals)
+}
+
+func TestChainStreamServerInterceptors_one(t *testing.T) {
+	var vals []int
+	interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
+		info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		vals = append(vals, 1)
+		return handler(srv, ss)
+	})
+	err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
+		vals = append(vals, 2)
+		return nil
+	})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1, 2}, vals)
+}
+
+func TestChainStreamServerInterceptors_more(t *testing.T) {
+	var vals []int
+	interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
+		info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		vals = append(vals, 1)
+		return handler(srv, ss)
+	}, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		vals = append(vals, 2)
+		return handler(srv, ss)
+	})
+	err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
+		vals = append(vals, 3)
+		return nil
+	})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
+}
+
+func TestChainUnaryServerInterceptors_zero(t *testing.T) {
+	var vals []int
+	interceptors := chainUnaryServerInterceptors()
+	_, err := interceptors(context.Background(), nil, nil,
+		func(ctx context.Context, req interface{}) (interface{}, error) {
+			vals = append(vals, 1)
+			return nil, nil
+		})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1}, vals)
+}
+
+func TestChainUnaryServerInterceptors_one(t *testing.T) {
+	var vals []int
+	interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
+		info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+		vals = append(vals, 1)
+		return handler(ctx, req)
+	})
+	_, err := interceptors(context.Background(), nil, nil,
+		func(ctx context.Context, req interface{}) (interface{}, error) {
+			vals = append(vals, 2)
+			return nil, nil
+		})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1, 2}, vals)
+}
+
+func TestChainUnaryServerInterceptors_more(t *testing.T) {
+	var vals []int
+	interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
+		info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+		vals = append(vals, 1)
+		return handler(ctx, req)
+	}, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
+		handler grpc.UnaryHandler) (resp interface{}, err error) {
+		vals = append(vals, 2)
+		return handler(ctx, req)
+	})
+	_, err := interceptors(context.Background(), nil, nil,
+		func(ctx context.Context, req interface{}) (interface{}, error) {
+			vals = append(vals, 3)
+			return nil, nil
+		})
+	assert.Nil(t, err)
+	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
+}