rpcserver_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package internal
  import (
	"context"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/zeromicro/go-zero/core/proc"
	"github.com/zeromicro/go-zero/core/stat"
	"github.com/zeromicro/go-zero/zrpc/internal/mock"
	"google.golang.org/grpc"
)
  12. func TestRpcServer(t *testing.T) {
  13. metrics := stat.NewMetrics("foo")
  14. server := NewRpcServer("localhost:54321", ServerMiddlewaresConf{
  15. Trace: true,
  16. Recover: true,
  17. Stat: true,
  18. Prometheus: true,
  19. Breaker: true,
  20. }, WithMetrics(metrics), WithRpcHealth(true))
  21. server.SetName("mock")
  22. var wg, wgDone sync.WaitGroup
  23. var grpcServer *grpc.Server
  24. var lock sync.Mutex
  25. wg.Add(1)
  26. wgDone.Add(1)
  27. go func() {
  28. err := server.Start(func(server *grpc.Server) {
  29. lock.Lock()
  30. mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
  31. grpcServer = server
  32. lock.Unlock()
  33. wg.Done()
  34. })
  35. assert.Nil(t, err)
  36. wgDone.Done()
  37. }()
  38. wg.Wait()
  39. proc.WrapUp()
  40. lock.Lock()
  41. grpcServer.GracefulStop()
  42. lock.Unlock()
  43. proc.WrapUp()
  44. wgDone.Wait()
  45. }
  46. func TestRpcServer_WithBadAddress(t *testing.T) {
  47. server := NewRpcServer("localhost:111111", ServerMiddlewaresConf{
  48. Trace: true,
  49. Recover: true,
  50. Stat: true,
  51. Prometheus: true,
  52. Breaker: true,
  53. }, WithRpcHealth(true))
  54. server.SetName("mock")
  55. err := server.Start(func(server *grpc.Server) {
  56. mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
  57. })
  58. assert.NotNil(t, err)
  59. proc.WrapUp()
  60. }
  61. func TestRpcServer_buildUnaryInterceptor(t *testing.T) {
  62. tests := []struct {
  63. name string
  64. r *rpcServer
  65. len int
  66. }{
  67. {
  68. name: "empty",
  69. r: &rpcServer{
  70. baseRpcServer: &baseRpcServer{},
  71. },
  72. len: 0,
  73. },
  74. {
  75. name: "custom",
  76. r: &rpcServer{
  77. baseRpcServer: &baseRpcServer{
  78. unaryInterceptors: []grpc.UnaryServerInterceptor{
  79. func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
  80. handler grpc.UnaryHandler) (interface{}, error) {
  81. return nil, nil
  82. },
  83. },
  84. },
  85. },
  86. len: 1,
  87. },
  88. {
  89. name: "middleware",
  90. r: &rpcServer{
  91. baseRpcServer: &baseRpcServer{
  92. unaryInterceptors: []grpc.UnaryServerInterceptor{
  93. func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
  94. handler grpc.UnaryHandler) (interface{}, error) {
  95. return nil, nil
  96. },
  97. },
  98. },
  99. middlewares: ServerMiddlewaresConf{
  100. Trace: true,
  101. Recover: true,
  102. Stat: true,
  103. Prometheus: true,
  104. Breaker: true,
  105. },
  106. },
  107. len: 6,
  108. },
  109. }
  110. for _, test := range tests {
  111. t.Run(test.name, func(t *testing.T) {
  112. assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors()))
  113. })
  114. }
  115. }
  116. func TestRpcServer_buildStreamInterceptor(t *testing.T) {
  117. tests := []struct {
  118. name string
  119. r *rpcServer
  120. len int
  121. }{
  122. {
  123. name: "empty",
  124. r: &rpcServer{
  125. baseRpcServer: &baseRpcServer{},
  126. },
  127. len: 0,
  128. },
  129. {
  130. name: "custom",
  131. r: &rpcServer{
  132. baseRpcServer: &baseRpcServer{
  133. streamInterceptors: []grpc.StreamServerInterceptor{
  134. func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
  135. handler grpc.StreamHandler) error {
  136. return nil
  137. },
  138. },
  139. },
  140. },
  141. len: 1,
  142. },
  143. {
  144. name: "middleware",
  145. r: &rpcServer{
  146. baseRpcServer: &baseRpcServer{
  147. streamInterceptors: []grpc.StreamServerInterceptor{
  148. func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
  149. handler grpc.StreamHandler) error {
  150. return nil
  151. },
  152. },
  153. },
  154. middlewares: ServerMiddlewaresConf{
  155. Trace: true,
  156. Recover: true,
  157. Breaker: true,
  158. },
  159. },
  160. len: 4,
  161. },
  162. }
  163. for _, test := range tests {
  164. t.Run(test.name, func(t *testing.T) {
  165. assert.Equal(t, test.len, len(test.r.buildStreamInterceptors()))
  166. })
  167. }
  168. }