statinterceptor_test.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "net"
  5. "sync"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/zeromicro/go-zero/core/collection"
  10. "github.com/zeromicro/go-zero/core/lang"
  11. "github.com/zeromicro/go-zero/core/stat"
  12. "github.com/zeromicro/go-zero/core/syncx"
  13. "google.golang.org/grpc"
  14. "google.golang.org/grpc/peer"
  15. )
  16. func TestSetSlowThreshold(t *testing.T) {
  17. assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
  18. SetSlowThreshold(time.Second)
  19. // reset slowThreshold
  20. t.Cleanup(func() {
  21. slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
  22. })
  23. assert.Equal(t, time.Second, slowThreshold.Load())
  24. }
  25. func TestUnaryStatInterceptor(t *testing.T) {
  26. metrics := stat.NewMetrics("mock")
  27. interceptor := UnaryStatInterceptor(metrics, StatConf{})
  28. _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  29. FullMethod: "/",
  30. }, func(ctx context.Context, req any) (any, error) {
  31. return nil, nil
  32. })
  33. assert.Nil(t, err)
  34. }
  35. func TestLogDuration(t *testing.T) {
  36. addrs, err := net.InterfaceAddrs()
  37. assert.Nil(t, err)
  38. assert.True(t, len(addrs) > 0)
  39. tests := []struct {
  40. name string
  41. ctx context.Context
  42. req any
  43. duration time.Duration
  44. }{
  45. {
  46. name: "normal",
  47. ctx: context.Background(),
  48. req: "foo",
  49. },
  50. {
  51. name: "bad req",
  52. ctx: context.Background(),
  53. req: make(chan lang.PlaceholderType), // not marshalable
  54. },
  55. {
  56. name: "timeout",
  57. ctx: context.Background(),
  58. req: "foo",
  59. duration: time.Second,
  60. },
  61. {
  62. name: "timeout",
  63. ctx: peer.NewContext(context.Background(), &peer.Peer{
  64. Addr: addrs[0],
  65. }),
  66. req: "foo",
  67. },
  68. {
  69. name: "timeout",
  70. ctx: context.Background(),
  71. req: "foo",
  72. duration: slowThreshold.Load() + time.Second,
  73. },
  74. }
  75. for _, test := range tests {
  76. test := test
  77. t.Run(test.name, func(t *testing.T) {
  78. t.Parallel()
  79. assert.NotPanics(t, func() {
  80. logDuration(test.ctx, "foo", test.req, test.duration,
  81. collection.NewSet(), 0)
  82. })
  83. })
  84. }
  85. }
  86. func TestLogDurationWithoutContent(t *testing.T) {
  87. addrs, err := net.InterfaceAddrs()
  88. assert.Nil(t, err)
  89. assert.True(t, len(addrs) > 0)
  90. tests := []struct {
  91. name string
  92. ctx context.Context
  93. req any
  94. duration time.Duration
  95. }{
  96. {
  97. name: "normal",
  98. ctx: context.Background(),
  99. req: "foo",
  100. },
  101. {
  102. name: "bad req",
  103. ctx: context.Background(),
  104. req: make(chan lang.PlaceholderType), // not marshalable
  105. },
  106. {
  107. name: "timeout",
  108. ctx: context.Background(),
  109. req: "foo",
  110. duration: time.Second,
  111. },
  112. {
  113. name: "timeout",
  114. ctx: peer.NewContext(context.Background(), &peer.Peer{
  115. Addr: addrs[0],
  116. }),
  117. req: "foo",
  118. },
  119. {
  120. name: "timeout",
  121. ctx: context.Background(),
  122. req: "foo",
  123. duration: slowThreshold.Load() + time.Second,
  124. },
  125. }
  126. DontLogContentForMethod("foo")
  127. // reset ignoreContentMethods
  128. t.Cleanup(func() {
  129. ignoreContentMethods = sync.Map{}
  130. })
  131. for _, test := range tests {
  132. test := test
  133. t.Run(test.name, func(t *testing.T) {
  134. t.Parallel()
  135. assert.NotPanics(t, func() {
  136. logDuration(test.ctx, "foo", test.req, test.duration,
  137. collection.NewSet(), 0)
  138. })
  139. })
  140. }
  141. }
  142. func Test_shouldLogContent(t *testing.T) {
  143. type args struct {
  144. method string
  145. staticNotLoggingContentMethods []string
  146. }
  147. tests := []struct {
  148. name string
  149. args args
  150. want bool
  151. setup func()
  152. }{
  153. {
  154. "empty",
  155. args{
  156. method: "foo",
  157. },
  158. true,
  159. nil,
  160. },
  161. {
  162. "static",
  163. args{
  164. method: "foo",
  165. staticNotLoggingContentMethods: []string{"foo"},
  166. },
  167. false,
  168. nil,
  169. },
  170. {
  171. "dynamic",
  172. args{
  173. method: "foo",
  174. },
  175. false,
  176. func() {
  177. DontLogContentForMethod("foo")
  178. },
  179. },
  180. }
  181. for _, tt := range tests {
  182. t.Run(tt.name, func(t *testing.T) {
  183. if tt.setup != nil {
  184. tt.setup()
  185. }
  186. // reset ignoreContentMethods
  187. t.Cleanup(func() {
  188. ignoreContentMethods = sync.Map{}
  189. })
  190. set := collection.NewSet()
  191. set.AddStr(tt.args.staticNotLoggingContentMethods...)
  192. assert.Equalf(t, tt.want, shouldLogContent(tt.args.method, set), "shouldLogContent(%v, %v)", tt.args.method, tt.args.staticNotLoggingContentMethods)
  193. })
  194. }
  195. }
  196. func Test_isSlow(t *testing.T) {
  197. type args struct {
  198. duration time.Duration
  199. staticSlowThreshold time.Duration
  200. }
  201. tests := []struct {
  202. name string
  203. args args
  204. want bool
  205. setup func()
  206. }{
  207. {
  208. "default",
  209. args{
  210. duration: time.Millisecond * 501,
  211. },
  212. true,
  213. nil,
  214. },
  215. {
  216. "static",
  217. args{
  218. duration: time.Millisecond * 200,
  219. staticSlowThreshold: time.Millisecond * 100,
  220. },
  221. true,
  222. nil,
  223. },
  224. {
  225. "dynamic",
  226. args{
  227. duration: time.Millisecond * 200,
  228. },
  229. true,
  230. func() {
  231. SetSlowThreshold(time.Millisecond * 100)
  232. },
  233. },
  234. }
  235. for _, tt := range tests {
  236. t.Run(tt.name, func(t *testing.T) {
  237. if tt.setup != nil {
  238. tt.setup()
  239. }
  240. // reset slowThreshold
  241. t.Cleanup(func() {
  242. slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
  243. })
  244. assert.Equalf(t, tt.want, isSlow(tt.args.duration, tt.args.staticSlowThreshold), "isSlow(%v, %v)", tt.args.duration, tt.args.staticSlowThreshold)
  245. })
  246. }
  247. }