loghandler_test.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package handler
  2. import (
  3. "bytes"
  4. "errors"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "time"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/wuntsong-org/go-zero-plus/rest/internal"
  12. "github.com/wuntsong-org/go-zero-plus/rest/internal/response"
  13. )
  14. func TestLogHandler(t *testing.T) {
  15. handlers := []func(handler http.Handler) http.Handler{
  16. LogHandler,
  17. DetailedLogHandler,
  18. }
  19. for _, logHandler := range handlers {
  20. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  21. handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  22. internal.LogCollectorFromContext(r.Context()).Append("anything")
  23. w.Header().Set("X-Test", "test")
  24. w.WriteHeader(http.StatusServiceUnavailable)
  25. _, err := w.Write([]byte("content"))
  26. assert.Nil(t, err)
  27. flusher, ok := w.(http.Flusher)
  28. assert.True(t, ok)
  29. flusher.Flush()
  30. }))
  31. resp := httptest.NewRecorder()
  32. handler.ServeHTTP(resp, req)
  33. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  34. assert.Equal(t, "test", resp.Header().Get("X-Test"))
  35. assert.Equal(t, "content", resp.Body.String())
  36. }
  37. }
  38. func TestLogHandlerVeryLong(t *testing.T) {
  39. var buf bytes.Buffer
  40. for i := 0; i < limitBodyBytes<<1; i++ {
  41. buf.WriteByte('a')
  42. }
  43. req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf)
  44. handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  45. internal.LogCollectorFromContext(r.Context()).Append("anything")
  46. _, _ = io.Copy(io.Discard, r.Body)
  47. w.Header().Set("X-Test", "test")
  48. w.WriteHeader(http.StatusServiceUnavailable)
  49. _, err := w.Write([]byte("content"))
  50. assert.Nil(t, err)
  51. flusher, ok := w.(http.Flusher)
  52. assert.True(t, ok)
  53. flusher.Flush()
  54. }))
  55. resp := httptest.NewRecorder()
  56. handler.ServeHTTP(resp, req)
  57. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  58. assert.Equal(t, "test", resp.Header().Get("X-Test"))
  59. assert.Equal(t, "content", resp.Body.String())
  60. }
  61. func TestLogHandlerSlow(t *testing.T) {
  62. handlers := []func(handler http.Handler) http.Handler{
  63. LogHandler,
  64. DetailedLogHandler,
  65. }
  66. for _, logHandler := range handlers {
  67. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  68. handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  69. time.Sleep(defaultSlowThreshold + time.Millisecond*50)
  70. }))
  71. resp := httptest.NewRecorder()
  72. handler.ServeHTTP(resp, req)
  73. assert.Equal(t, http.StatusOK, resp.Code)
  74. }
  75. }
  76. func TestDetailedLogHandler_Hijack(t *testing.T) {
  77. resp := httptest.NewRecorder()
  78. writer := &detailLoggedResponseWriter{
  79. writer: response.NewWithCodeResponseWriter(resp),
  80. }
  81. assert.NotPanics(t, func() {
  82. _, _, _ = writer.Hijack()
  83. })
  84. writer = &detailLoggedResponseWriter{
  85. writer: response.NewWithCodeResponseWriter(resp),
  86. }
  87. assert.NotPanics(t, func() {
  88. _, _, _ = writer.Hijack()
  89. })
  90. writer = &detailLoggedResponseWriter{
  91. writer: response.NewWithCodeResponseWriter(mockedHijackable{
  92. ResponseRecorder: resp,
  93. }),
  94. }
  95. assert.NotPanics(t, func() {
  96. _, _, _ = writer.Hijack()
  97. })
  98. }
  99. func TestSetSlowThreshold(t *testing.T) {
  100. assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
  101. SetSlowThreshold(time.Second)
  102. assert.Equal(t, time.Second, slowThreshold.Load())
  103. }
  104. func TestWrapMethodWithColor(t *testing.T) {
  105. // no tty
  106. assert.Equal(t, http.MethodGet, wrapMethod(http.MethodGet))
  107. assert.Equal(t, http.MethodPost, wrapMethod(http.MethodPost))
  108. assert.Equal(t, http.MethodPut, wrapMethod(http.MethodPut))
  109. assert.Equal(t, http.MethodDelete, wrapMethod(http.MethodDelete))
  110. assert.Equal(t, http.MethodPatch, wrapMethod(http.MethodPatch))
  111. assert.Equal(t, http.MethodHead, wrapMethod(http.MethodHead))
  112. assert.Equal(t, http.MethodOptions, wrapMethod(http.MethodOptions))
  113. assert.Equal(t, http.MethodConnect, wrapMethod(http.MethodConnect))
  114. assert.Equal(t, http.MethodTrace, wrapMethod(http.MethodTrace))
  115. }
  116. func TestWrapStatusCodeWithColor(t *testing.T) {
  117. // no tty
  118. assert.Equal(t, "200", wrapStatusCode(http.StatusOK))
  119. assert.Equal(t, "302", wrapStatusCode(http.StatusFound))
  120. assert.Equal(t, "404", wrapStatusCode(http.StatusNotFound))
  121. assert.Equal(t, "500", wrapStatusCode(http.StatusInternalServerError))
  122. assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
  123. }
  124. func TestDumpRequest(t *testing.T) {
  125. const errMsg = "error"
  126. r := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  127. r.Body = mockedReadCloser{errMsg: errMsg}
  128. assert.Equal(t, errMsg, dumpRequest(r))
  129. }
  130. func BenchmarkLogHandler(b *testing.B) {
  131. b.ReportAllocs()
  132. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  133. handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  134. w.WriteHeader(http.StatusOK)
  135. }))
  136. for i := 0; i < b.N; i++ {
  137. resp := httptest.NewRecorder()
  138. handler.ServeHTTP(resp, req)
  139. }
  140. }
  // mockedReadCloser is a request-body stub whose Read always fails and whose
// Close always succeeds.
type mockedReadCloser struct {
	// errMsg is the text of the error returned by every Read call.
	errMsg string
}

// Read consumes nothing and returns a new error whose text is errMsg.
func (m mockedReadCloser) Read(_ []byte) (int, error) {
	err := errors.New(m.errMsg)
	return 0, err
}

// Close does nothing and always reports success.
func (m mockedReadCloser) Close() error {
	return nil
}