// loghandler_test.go (4.6 KB)
  1. package handler
  2. import (
  3. "bytes"
  4. "io"
  5. "log"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "time"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/zeromicro/go-zero/rest/internal"
  12. )
  13. func init() {
  14. log.SetOutput(io.Discard)
  15. }
  16. func TestLogHandler(t *testing.T) {
  17. handlers := []func(handler http.Handler) http.Handler{
  18. LogHandler,
  19. DetailedLogHandler,
  20. }
  21. for _, logHandler := range handlers {
  22. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  23. handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  24. r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
  25. w.Header().Set("X-Test", "test")
  26. w.WriteHeader(http.StatusServiceUnavailable)
  27. _, err := w.Write([]byte("content"))
  28. assert.Nil(t, err)
  29. flusher, ok := w.(http.Flusher)
  30. assert.True(t, ok)
  31. flusher.Flush()
  32. }))
  33. resp := httptest.NewRecorder()
  34. handler.ServeHTTP(resp, req)
  35. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  36. assert.Equal(t, "test", resp.Header().Get("X-Test"))
  37. assert.Equal(t, "content", resp.Body.String())
  38. }
  39. }
  40. func TestLogHandlerVeryLong(t *testing.T) {
  41. var buf bytes.Buffer
  42. for i := 0; i < limitBodyBytes<<1; i++ {
  43. buf.WriteByte('a')
  44. }
  45. req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf)
  46. handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  47. r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
  48. io.Copy(io.Discard, r.Body)
  49. w.Header().Set("X-Test", "test")
  50. w.WriteHeader(http.StatusServiceUnavailable)
  51. _, err := w.Write([]byte("content"))
  52. assert.Nil(t, err)
  53. flusher, ok := w.(http.Flusher)
  54. assert.True(t, ok)
  55. flusher.Flush()
  56. }))
  57. resp := httptest.NewRecorder()
  58. handler.ServeHTTP(resp, req)
  59. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  60. assert.Equal(t, "test", resp.Header().Get("X-Test"))
  61. assert.Equal(t, "content", resp.Body.String())
  62. }
  63. func TestLogHandlerSlow(t *testing.T) {
  64. handlers := []func(handler http.Handler) http.Handler{
  65. LogHandler,
  66. DetailedLogHandler,
  67. }
  68. for _, logHandler := range handlers {
  69. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  70. handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  71. time.Sleep(defaultSlowThreshold + time.Millisecond*50)
  72. }))
  73. resp := httptest.NewRecorder()
  74. handler.ServeHTTP(resp, req)
  75. assert.Equal(t, http.StatusOK, resp.Code)
  76. }
  77. }
  78. func TestLogHandler_Hijack(t *testing.T) {
  79. resp := httptest.NewRecorder()
  80. writer := &loggedResponseWriter{
  81. w: resp,
  82. }
  83. assert.NotPanics(t, func() {
  84. writer.Hijack()
  85. })
  86. writer = &loggedResponseWriter{
  87. w: mockedHijackable{resp},
  88. }
  89. assert.NotPanics(t, func() {
  90. writer.Hijack()
  91. })
  92. }
  93. func TestDetailedLogHandler_Hijack(t *testing.T) {
  94. resp := httptest.NewRecorder()
  95. writer := &detailLoggedResponseWriter{
  96. writer: &loggedResponseWriter{
  97. w: resp,
  98. },
  99. }
  100. assert.NotPanics(t, func() {
  101. writer.Hijack()
  102. })
  103. writer = &detailLoggedResponseWriter{
  104. writer: &loggedResponseWriter{
  105. w: mockedHijackable{resp},
  106. },
  107. }
  108. assert.NotPanics(t, func() {
  109. writer.Hijack()
  110. })
  111. }
  112. func TestSetSlowThreshold(t *testing.T) {
  113. assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
  114. SetSlowThreshold(time.Second)
  115. assert.Equal(t, time.Second, slowThreshold.Load())
  116. }
  117. func TestWrapMethodWithColor(t *testing.T) {
  118. // no tty
  119. assert.Equal(t, http.MethodGet, wrapMethod(http.MethodGet))
  120. assert.Equal(t, http.MethodPost, wrapMethod(http.MethodPost))
  121. assert.Equal(t, http.MethodPut, wrapMethod(http.MethodPut))
  122. assert.Equal(t, http.MethodDelete, wrapMethod(http.MethodDelete))
  123. assert.Equal(t, http.MethodPatch, wrapMethod(http.MethodPatch))
  124. assert.Equal(t, http.MethodHead, wrapMethod(http.MethodHead))
  125. assert.Equal(t, http.MethodOptions, wrapMethod(http.MethodOptions))
  126. assert.Equal(t, http.MethodConnect, wrapMethod(http.MethodConnect))
  127. assert.Equal(t, http.MethodTrace, wrapMethod(http.MethodTrace))
  128. }
  129. func TestWrapStatusCodeWithColor(t *testing.T) {
  130. // no tty
  131. assert.Equal(t, "200", wrapStatusCode(http.StatusOK))
  132. assert.Equal(t, "302", wrapStatusCode(http.StatusFound))
  133. assert.Equal(t, "404", wrapStatusCode(http.StatusNotFound))
  134. assert.Equal(t, "500", wrapStatusCode(http.StatusInternalServerError))
  135. assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
  136. }
  137. func BenchmarkLogHandler(b *testing.B) {
  138. b.ReportAllocs()
  139. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  140. handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  141. w.WriteHeader(http.StatusOK)
  142. }))
  143. for i := 0; i < b.N; i++ {
  144. resp := httptest.NewRecorder()
  145. handler.ServeHTTP(resp, req)
  146. }
  147. }