1
0
Эх сурвалжийг харах

feat: converge grpc interceptor processing (#2830)

* feat: converge grpc interceptor processing

* x

* x
MarkJoyMa 2 жил өмнө
parent
commit
dd117ce9cf

+ 6 - 8
zrpc/internal/rpcserver.go

@@ -59,12 +59,10 @@ func (s *rpcServer) Start(register RegisterFn) error {
 		return err
 		return err
 	}
 	}
 
 
-	unaryInterceptors := s.buildUnaryInterceptors()
-	unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
-	streamInterceptors := s.buildStreamInterceptors()
-	streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
-	options := append(s.options, grpc.ChainUnaryInterceptor(unaryInterceptors...),
-		grpc.ChainStreamInterceptor(streamInterceptors...))
+	unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...)
+	streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...)
+
+	options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
 	server := grpc.NewServer(options...)
 	server := grpc.NewServer(options...)
 	register(server)
 	register(server)
 
 
@@ -102,7 +100,7 @@ func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
 		interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
 		interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
 	}
 	}
 
 
-	return interceptors
+	return append(interceptors, s.streamInterceptors...)
 }
 }
 
 
 func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
 func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
@@ -124,7 +122,7 @@ func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
 		interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
 		interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
 	}
 	}
 
 
-	return interceptors
+	return append(interceptors, s.unaryInterceptors...)
 }
 }
 
 
 // WithMetrics returns a func that sets metrics to a Server.
 // WithMetrics returns a func that sets metrics to a Server.

+ 113 - 0
zrpc/internal/rpcserver_test.go

@@ -1,6 +1,7 @@
 package internal
 package internal
 
 
 import (
 import (
+	"context"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
 
 
@@ -58,3 +59,115 @@ func TestRpcServer_WithBadAddress(t *testing.T) {
 	})
 	})
 	assert.NotNil(t, err)
 	assert.NotNil(t, err)
 }
 }
+
+func TestRpcServer_buildUnaryInterceptor(t *testing.T) {
+	tests := []struct {
+		name string
+		r    *rpcServer
+		len  int
+	}{
+		{
+			name: "empty",
+			r: &rpcServer{
+				baseRpcServer: &baseRpcServer{},
+			},
+			len: 0,
+		},
+		{
+			name: "custom",
+			r: &rpcServer{
+				baseRpcServer: &baseRpcServer{
+					unaryInterceptors: []grpc.UnaryServerInterceptor{
+						func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
+							handler grpc.UnaryHandler) (interface{}, error) {
+							return nil, nil
+						},
+					},
+				},
+			},
+			len: 1,
+		},
+		{
+			name: "middleware",
+			r: &rpcServer{
+				baseRpcServer: &baseRpcServer{
+					unaryInterceptors: []grpc.UnaryServerInterceptor{
+						func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
+							handler grpc.UnaryHandler) (interface{}, error) {
+							return nil, nil
+						},
+					},
+				},
+				middlewares: ServerMiddlewaresConf{
+					Trace:      true,
+					Recover:    true,
+					Stat:       true,
+					Prometheus: true,
+					Breaker:    true,
+				},
+			},
+			len: 6,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors()))
+		})
+	}
+}
+
+func TestRpcServer_buildStreamInterceptor(t *testing.T) {
+	tests := []struct {
+		name string
+		r    *rpcServer
+		len  int
+	}{
+		{
+			name: "empty",
+			r: &rpcServer{
+				baseRpcServer: &baseRpcServer{},
+			},
+			len: 0,
+		},
+		{
+			name: "custom",
+			r: &rpcServer{
+				baseRpcServer: &baseRpcServer{
+					streamInterceptors: []grpc.StreamServerInterceptor{
+						func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
+							handler grpc.StreamHandler) error {
+							return nil
+						},
+					},
+				},
+			},
+			len: 1,
+		},
+		{
+			name: "middleware",
+			r: &rpcServer{
+				baseRpcServer: &baseRpcServer{
+					streamInterceptors: []grpc.StreamServerInterceptor{
+						func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
+							handler grpc.StreamHandler) error {
+							return nil
+						},
+					},
+				},
+				middlewares: ServerMiddlewaresConf{
+					Trace:   true,
+					Recover: true,
+					Breaker: true,
+				},
+			},
+			len: 4,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			assert.Equal(t, test.len, len(test.r.buildStreamInterceptors()))
+		})
+	}
+}