timeouthandler_test.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package handler
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "strconv"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/zeromicro/go-zero/core/logx/logtest"
  14. "github.com/zeromicro/go-zero/rest/internal/response"
  15. )
  16. func TestTimeoutWriteFlushOutput(t *testing.T) {
  17. timeoutHandler := TimeoutHandler(1000 * time.Millisecond)
  18. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  19. w.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
  20. flusher, ok := w.(http.Flusher)
  21. if !ok {
  22. http.Error(w, "Flushing not supported", http.StatusInternalServerError)
  23. return
  24. }
  25. for i := 1; i <= 5; i++ {
  26. fmt.Fprint(w, strconv.Itoa(i)+"只猫猫\n\n")
  27. flusher.Flush()
  28. time.Sleep(time.Millisecond)
  29. }
  30. }))
  31. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  32. resp := httptest.NewRecorder()
  33. handler.ServeHTTP(resp, req)
  34. scanner := bufio.NewScanner(resp.Body)
  35. mao := 0
  36. for scanner.Scan() {
  37. line := scanner.Text()
  38. if strings.Contains(line, "猫猫") {
  39. mao++
  40. }
  41. }
  42. if err := scanner.Err(); err != nil {
  43. mao = 0
  44. }
  45. assert.Equal(t, "5只猫猫", strconv.Itoa(mao)+"只猫猫")
  46. }
  47. func TestTimeout(t *testing.T) {
  48. timeoutHandler := TimeoutHandler(time.Millisecond)
  49. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  50. time.Sleep(time.Minute)
  51. }))
  52. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  53. resp := httptest.NewRecorder()
  54. handler.ServeHTTP(resp, req)
  55. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  56. }
  57. func TestWithinTimeout(t *testing.T) {
  58. timeoutHandler := TimeoutHandler(time.Second)
  59. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  60. time.Sleep(time.Millisecond)
  61. }))
  62. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  63. resp := httptest.NewRecorder()
  64. handler.ServeHTTP(resp, req)
  65. assert.Equal(t, http.StatusOK, resp.Code)
  66. }
  67. func TestWithTimeoutTimedout(t *testing.T) {
  68. timeoutHandler := TimeoutHandler(time.Millisecond)
  69. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  70. time.Sleep(time.Millisecond * 10)
  71. _, err := w.Write([]byte(`foo`))
  72. if err != nil {
  73. w.WriteHeader(http.StatusInternalServerError)
  74. return
  75. }
  76. w.WriteHeader(http.StatusOK)
  77. }))
  78. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  79. resp := httptest.NewRecorder()
  80. handler.ServeHTTP(resp, req)
  81. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  82. }
  83. func TestWithoutTimeout(t *testing.T) {
  84. timeoutHandler := TimeoutHandler(0)
  85. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  86. time.Sleep(100 * time.Millisecond)
  87. }))
  88. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  89. resp := httptest.NewRecorder()
  90. handler.ServeHTTP(resp, req)
  91. assert.Equal(t, http.StatusOK, resp.Code)
  92. }
  93. func TestTimeoutPanic(t *testing.T) {
  94. timeoutHandler := TimeoutHandler(time.Minute)
  95. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  96. panic("foo")
  97. }))
  98. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  99. resp := httptest.NewRecorder()
  100. assert.Panics(t, func() {
  101. handler.ServeHTTP(resp, req)
  102. })
  103. }
  104. func TestTimeoutWebsocket(t *testing.T) {
  105. timeoutHandler := TimeoutHandler(time.Millisecond)
  106. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  107. time.Sleep(time.Millisecond * 10)
  108. }))
  109. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  110. req.Header.Set(headerUpgrade, valueWebsocket)
  111. resp := httptest.NewRecorder()
  112. handler.ServeHTTP(resp, req)
  113. assert.Equal(t, http.StatusOK, resp.Code)
  114. }
  115. func TestTimeoutWroteHeaderTwice(t *testing.T) {
  116. timeoutHandler := TimeoutHandler(time.Minute)
  117. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  118. _, err := w.Write([]byte(`hello`))
  119. if err != nil {
  120. w.WriteHeader(http.StatusInternalServerError)
  121. return
  122. }
  123. w.Header().Set("foo", "bar")
  124. w.WriteHeader(http.StatusOK)
  125. }))
  126. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  127. resp := httptest.NewRecorder()
  128. handler.ServeHTTP(resp, req)
  129. assert.Equal(t, http.StatusOK, resp.Code)
  130. }
  131. func TestTimeoutWriteBadCode(t *testing.T) {
  132. timeoutHandler := TimeoutHandler(time.Minute)
  133. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  134. w.WriteHeader(1000)
  135. }))
  136. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  137. resp := httptest.NewRecorder()
  138. assert.Panics(t, func() {
  139. handler.ServeHTTP(resp, req)
  140. })
  141. }
  142. func TestTimeoutClientClosed(t *testing.T) {
  143. timeoutHandler := TimeoutHandler(time.Minute)
  144. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  145. w.WriteHeader(http.StatusServiceUnavailable)
  146. }))
  147. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  148. ctx, cancel := context.WithCancel(context.Background())
  149. req = req.WithContext(ctx)
  150. cancel()
  151. resp := httptest.NewRecorder()
  152. handler.ServeHTTP(resp, req)
  153. assert.Equal(t, statusClientClosedRequest, resp.Code)
  154. }
  155. func TestTimeoutHijack(t *testing.T) {
  156. resp := httptest.NewRecorder()
  157. writer := &timeoutWriter{
  158. w: &response.WithCodeResponseWriter{
  159. Writer: resp,
  160. },
  161. }
  162. assert.NotPanics(t, func() {
  163. _, _, _ = writer.Hijack()
  164. })
  165. writer = &timeoutWriter{
  166. w: &response.WithCodeResponseWriter{
  167. Writer: mockedHijackable{resp},
  168. },
  169. }
  170. assert.NotPanics(t, func() {
  171. _, _, _ = writer.Hijack()
  172. })
  173. }
  174. func TestTimeoutFlush(t *testing.T) {
  175. timeoutHandler := TimeoutHandler(time.Minute)
  176. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  177. flusher, ok := w.(http.Flusher)
  178. if !ok {
  179. http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
  180. return
  181. }
  182. flusher.Flush()
  183. }))
  184. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  185. resp := httptest.NewRecorder()
  186. handler.ServeHTTP(resp, req)
  187. assert.Equal(t, http.StatusOK, resp.Code)
  188. }
  189. func TestTimeoutPusher(t *testing.T) {
  190. handler := &timeoutWriter{
  191. w: mockedPusher{},
  192. }
  193. assert.Panics(t, func() {
  194. _ = handler.Push("any", nil)
  195. })
  196. handler = &timeoutWriter{
  197. w: httptest.NewRecorder(),
  198. }
  199. assert.Equal(t, http.ErrNotSupported, handler.Push("any", nil))
  200. }
  201. func TestTimeoutWriter_Hijack(t *testing.T) {
  202. writer := &timeoutWriter{
  203. w: httptest.NewRecorder(),
  204. h: make(http.Header),
  205. req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
  206. }
  207. _, _, err := writer.Hijack()
  208. assert.Error(t, err)
  209. }
  210. func TestTimeoutWroteTwice(t *testing.T) {
  211. c := logtest.NewCollector(t)
  212. writer := &timeoutWriter{
  213. w: &response.WithCodeResponseWriter{
  214. Writer: httptest.NewRecorder(),
  215. },
  216. h: make(http.Header),
  217. req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
  218. }
  219. writer.writeHeaderLocked(http.StatusOK)
  220. writer.writeHeaderLocked(http.StatusOK)
  221. assert.Contains(t, c.String(), "superfluous response.WriteHeader call")
  222. }
  // mockedPusher is a stub with the method sets of both http.ResponseWriter and
  // http.Pusher. Every method panics with "implement me", so a test can tell
  // whether a call was forwarded to it.
  type mockedPusher struct{}

  // Header always panics.
  func (mockedPusher) Header() http.Header {
  	panic("implement me")
  }

  // Write always panics.
  func (mockedPusher) Write(_ []byte) (int, error) {
  	panic("implement me")
  }

  // WriteHeader always panics.
  func (mockedPusher) WriteHeader(_ int) {
  	panic("implement me")
  }

  // Push always panics.
  func (mockedPusher) Push(_ string, _ *http.PushOptions) error {
  	panic("implement me")
  }