statinterceptor_test.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "net"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/tal-tech/go-zero/core/lang"
  9. "github.com/tal-tech/go-zero/core/stat"
  10. "google.golang.org/grpc"
  11. "google.golang.org/grpc/peer"
  12. )
  13. func TestUnaryStatInterceptor(t *testing.T) {
  14. metrics := stat.NewMetrics("mock")
  15. interceptor := UnaryStatInterceptor(metrics)
  16. _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  17. FullMethod: "/",
  18. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  19. return nil, nil
  20. })
  21. assert.Nil(t, err)
  22. }
  23. func TestUnaryStatInterceptor_crash(t *testing.T) {
  24. metrics := stat.NewMetrics("mock")
  25. interceptor := UnaryStatInterceptor(metrics)
  26. _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  27. FullMethod: "/",
  28. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  29. panic("error")
  30. })
  31. assert.NotNil(t, err)
  32. }
  33. func TestLogDuration(t *testing.T) {
  34. addrs, err := net.InterfaceAddrs()
  35. assert.Nil(t, err)
  36. assert.True(t, len(addrs) > 0)
  37. tests := []struct {
  38. name string
  39. ctx context.Context
  40. req interface{}
  41. duration time.Duration
  42. }{
  43. {
  44. name: "normal",
  45. ctx: context.Background(),
  46. req: "foo",
  47. },
  48. {
  49. name: "bad req",
  50. ctx: context.Background(),
  51. req: make(chan lang.PlaceholderType), // not marshalable
  52. },
  53. {
  54. name: "timeout",
  55. ctx: context.Background(),
  56. req: "foo",
  57. duration: time.Second,
  58. },
  59. {
  60. name: "timeout",
  61. ctx: peer.NewContext(context.Background(), &peer.Peer{
  62. Addr: addrs[0],
  63. }),
  64. req: "foo",
  65. },
  66. }
  67. for _, test := range tests {
  68. test := test
  69. t.Run(test.name, func(t *testing.T) {
  70. t.Parallel()
  71. assert.NotPanics(t, func() {
  72. logDuration(test.ctx, "foo", test.req, test.duration)
  73. })
  74. })
  75. }
  76. }