Răsfoiți Sursa

change grpc interceptor to chain interceptor (#200)

* change grpc interceptor to chain interceptor

* change server rpc interceptors, del testing code
SunJun 4 ani în urmă
părinte
comite
0a2c2d1eca

+ 3 - 73
zrpc/internal/chainclientinterceptors.go

@@ -1,83 +1,13 @@
 package internal
 
 import (
-	"context"
-
 	"google.golang.org/grpc"
 )
 
 func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption {
-	return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...))
+	return grpc.WithChainStreamInterceptor(interceptors...)
 }
 
 func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption {
-	return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...))
-}
-
-func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
-	switch len(interceptors) {
-	case 0:
-		return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
-			streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
-			return streamer(ctx, desc, cc, method, opts...)
-		}
-	case 1:
-		return interceptors[0]
-	default:
-		last := len(interceptors) - 1
-		return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
-			method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
-			var chainStreamer grpc.Streamer
-			var current int
-
-			chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn,
-				curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) {
-				if current == last {
-					return streamer(curCtx, curDesc, curCc, curMethod, curOpts...)
-				}
-
-				current++
-				clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...)
-				current--
-
-				return clientStream, err
-			}
-
-			return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...)
-		}
-	}
-}
-
-func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
-	switch len(interceptors) {
-	case 0:
-		return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
-			invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
-			return invoker(ctx, method, req, reply, cc, opts...)
-		}
-	case 1:
-		return interceptors[0]
-	default:
-		last := len(interceptors) - 1
-		return func(ctx context.Context, method string, req, reply interface{},
-			cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
-			var chainInvoker grpc.UnaryInvoker
-			var current int
-
-			chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{},
-				curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error {
-				if current == last {
-					return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...)
-				}
-
-				current++
-				err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...)
-				current--
-
-				return err
-			}
-
-			return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...)
-		}
-	}
-}
+	return grpc.WithChainUnaryInterceptor(interceptors...)
+}

+ 1 - 107
zrpc/internal/chainclientinterceptors_test.go

@@ -1,11 +1,9 @@
 package internal
 
 import (
-	"context"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"google.golang.org/grpc"
 )
 
 func TestWithStreamClientInterceptors(t *testing.T) {
@@ -16,108 +14,4 @@ func TestWithStreamClientInterceptors(t *testing.T) {
 func TestWithUnaryClientInterceptors(t *testing.T) {
 	opts := WithUnaryClientInterceptors()
 	assert.NotNil(t, opts)
-}
-
-func TestChainStreamClientInterceptors_zero(t *testing.T) {
-	var vals []int
-	interceptors := chainStreamClientInterceptors()
-	_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
-		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
-			opts ...grpc.CallOption) (grpc.ClientStream, error) {
-			vals = append(vals, 1)
-			return nil, nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1}, vals)
-}
-
-func TestChainStreamClientInterceptors_one(t *testing.T) {
-	var vals []int
-	interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
-		cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
-		grpc.ClientStream, error) {
-		vals = append(vals, 1)
-		return streamer(ctx, desc, cc, method, opts...)
-	})
-	_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
-		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
-			opts ...grpc.CallOption) (grpc.ClientStream, error) {
-			vals = append(vals, 2)
-			return nil, nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2}, vals)
-}
-
-func TestChainStreamClientInterceptors_more(t *testing.T) {
-	var vals []int
-	interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
-		cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
-		grpc.ClientStream, error) {
-		vals = append(vals, 1)
-		return streamer(ctx, desc, cc, method, opts...)
-	}, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
-		streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
-		vals = append(vals, 2)
-		return streamer(ctx, desc, cc, method, opts...)
-	})
-	_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
-		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
-			opts ...grpc.CallOption) (grpc.ClientStream, error) {
-			vals = append(vals, 3)
-			return nil, nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
-}
-
-func TestWithUnaryClientInterceptors_zero(t *testing.T) {
-	var vals []int
-	interceptors := chainUnaryClientInterceptors()
-	err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
-		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
-			opts ...grpc.CallOption) error {
-			vals = append(vals, 1)
-			return nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1}, vals)
-}
-
-func TestWithUnaryClientInterceptors_one(t *testing.T) {
-	var vals []int
-	interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
-		reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
-		vals = append(vals, 1)
-		return invoker(ctx, method, req, reply, cc, opts...)
-	})
-	err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
-		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
-			opts ...grpc.CallOption) error {
-			vals = append(vals, 2)
-			return nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2}, vals)
-}
-
-func TestWithUnaryClientInterceptors_more(t *testing.T) {
-	var vals []int
-	interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
-		reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
-		vals = append(vals, 1)
-		return invoker(ctx, method, req, reply, cc, opts...)
-	}, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
-		invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
-		vals = append(vals, 2)
-		return invoker(ctx, method, req, reply, cc, opts...)
-	})
-	err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
-		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
-			opts ...grpc.CallOption) error {
-			vals = append(vals, 3)
-			return nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
-}
+}

+ 3 - 71
zrpc/internal/chainserverinterceptors.go

@@ -1,81 +1,13 @@
 package internal
 
 import (
-	"context"
-
 	"google.golang.org/grpc"
 )
 
 func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
-	return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...))
+	return grpc.ChainStreamInterceptor(interceptors...)
 }
 
 func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
-	return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...))
-}
-
-func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
-	switch len(interceptors) {
-	case 0:
-		return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
-			handler grpc.StreamHandler) error {
-			return handler(srv, stream)
-		}
-	case 1:
-		return interceptors[0]
-	default:
-		last := len(interceptors) - 1
-		return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
-			handler grpc.StreamHandler) error {
-			var chainHandler grpc.StreamHandler
-			var current int
-
-			chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error {
-				if current == last {
-					return handler(curSrv, curStream)
-				}
-
-				current++
-				err := interceptors[current](curSrv, curStream, info, chainHandler)
-				current--
-
-				return err
-			}
-
-			return interceptors[0](srv, stream, info, chainHandler)
-		}
-	}
-}
-
-func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
-	switch len(interceptors) {
-	case 0:
-		return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
-			interface{}, error) {
-			return handler(ctx, req)
-		}
-	case 1:
-		return interceptors[0]
-	default:
-		last := len(interceptors) - 1
-		return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
-			interface{}, error) {
-			var chainHandler grpc.UnaryHandler
-			var current int
-
-			chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) {
-				if current == last {
-					return handler(curCtx, curReq)
-				}
-
-				current++
-				resp, err := interceptors[current](curCtx, curReq, info, chainHandler)
-				current--
-
-				return resp, err
-			}
-
-			return interceptors[0](ctx, req, info, chainHandler)
-		}
-	}
-}
+	return grpc.ChainUnaryInterceptor(interceptors...)
+}

+ 1 - 95
zrpc/internal/chainserverinterceptors_test.go

@@ -1,11 +1,9 @@
 package internal
 
 import (
-	"context"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"google.golang.org/grpc"
 )
 
 func TestWithStreamServerInterceptors(t *testing.T) {
@@ -16,96 +14,4 @@ func TestWithStreamServerInterceptors(t *testing.T) {
 func TestWithUnaryServerInterceptors(t *testing.T) {
 	opts := WithUnaryServerInterceptors()
 	assert.NotNil(t, opts)
-}
-
-func TestChainStreamServerInterceptors_zero(t *testing.T) {
-	var vals []int
-	interceptors := chainStreamServerInterceptors()
-	err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
-		vals = append(vals, 1)
-		return nil
-	})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1}, vals)
-}
-
-func TestChainStreamServerInterceptors_one(t *testing.T) {
-	var vals []int
-	interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
-		info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
-		vals = append(vals, 1)
-		return handler(srv, ss)
-	})
-	err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
-		vals = append(vals, 2)
-		return nil
-	})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2}, vals)
-}
-
-func TestChainStreamServerInterceptors_more(t *testing.T) {
-	var vals []int
-	interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
-		info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
-		vals = append(vals, 1)
-		return handler(srv, ss)
-	}, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
-		vals = append(vals, 2)
-		return handler(srv, ss)
-	})
-	err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
-		vals = append(vals, 3)
-		return nil
-	})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
-}
-
-func TestChainUnaryServerInterceptors_zero(t *testing.T) {
-	var vals []int
-	interceptors := chainUnaryServerInterceptors()
-	_, err := interceptors(context.Background(), nil, nil,
-		func(ctx context.Context, req interface{}) (interface{}, error) {
-			vals = append(vals, 1)
-			return nil, nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1}, vals)
-}
-
-func TestChainUnaryServerInterceptors_one(t *testing.T) {
-	var vals []int
-	interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
-		info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
-		vals = append(vals, 1)
-		return handler(ctx, req)
-	})
-	_, err := interceptors(context.Background(), nil, nil,
-		func(ctx context.Context, req interface{}) (interface{}, error) {
-			vals = append(vals, 2)
-			return nil, nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2}, vals)
-}
-
-func TestChainUnaryServerInterceptors_more(t *testing.T) {
-	var vals []int
-	interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
-		info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
-		vals = append(vals, 1)
-		return handler(ctx, req)
-	}, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
-		handler grpc.UnaryHandler) (resp interface{}, err error) {
-		vals = append(vals, 2)
-		return handler(ctx, req)
-	})
-	_, err := interceptors(context.Background(), nil, nil,
-		func(ctx context.Context, req interface{}) (interface{}, error) {
-			vals = append(vals, 3)
-			return nil, nil
-		})
-	assert.Nil(t, err)
-	assert.ElementsMatch(t, []int{1, 2, 3}, vals)
-}
+}