timeouthandler_test.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. package handler
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "strconv"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/wuntsong-org/go-zero-plus/core/logx/logtest"
  14. "github.com/wuntsong-org/go-zero-plus/rest/internal/response"
  15. )
  16. func TestTimeoutWriteFlushOutput(t *testing.T) {
  17. t.Run("flusher", func(t *testing.T) {
  18. timeoutHandler := TimeoutHandler(1000 * time.Millisecond)
  19. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  20. w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
  21. flusher, ok := w.(http.Flusher)
  22. if !ok {
  23. http.Error(w, "Flushing not supported", http.StatusInternalServerError)
  24. return
  25. }
  26. for i := 1; i <= 5; i++ {
  27. fmt.Fprint(w, strconv.Itoa(i)+" cats\n\n")
  28. flusher.Flush()
  29. time.Sleep(time.Millisecond)
  30. }
  31. }))
  32. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  33. resp := httptest.NewRecorder()
  34. handler.ServeHTTP(resp, req)
  35. scanner := bufio.NewScanner(resp.Body)
  36. var cats int
  37. for scanner.Scan() {
  38. line := scanner.Text()
  39. if strings.Contains(line, "cats") {
  40. cats++
  41. }
  42. }
  43. if err := scanner.Err(); err != nil {
  44. cats = 0
  45. }
  46. assert.Equal(t, 5, cats)
  47. })
  48. t.Run("writer", func(t *testing.T) {
  49. recorder := httptest.NewRecorder()
  50. timeoutHandler := TimeoutHandler(1000 * time.Millisecond)
  51. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  52. w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
  53. flusher, ok := w.(http.Flusher)
  54. if !ok {
  55. http.Error(w, "Flushing not supported", http.StatusInternalServerError)
  56. return
  57. }
  58. for i := 1; i <= 5; i++ {
  59. fmt.Fprint(w, strconv.Itoa(i)+" cats\n\n")
  60. flusher.Flush()
  61. time.Sleep(time.Millisecond)
  62. assert.Empty(t, recorder.Body.String())
  63. }
  64. }))
  65. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  66. resp := mockedResponseWriter{recorder}
  67. handler.ServeHTTP(resp, req)
  68. assert.Equal(t, "1 cats\n\n2 cats\n\n3 cats\n\n4 cats\n\n5 cats\n\n",
  69. recorder.Body.String())
  70. })
  71. }
  72. func TestTimeout(t *testing.T) {
  73. timeoutHandler := TimeoutHandler(time.Millisecond)
  74. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  75. time.Sleep(time.Minute)
  76. }))
  77. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  78. resp := httptest.NewRecorder()
  79. handler.ServeHTTP(resp, req)
  80. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  81. }
  82. func TestWithinTimeout(t *testing.T) {
  83. timeoutHandler := TimeoutHandler(time.Second)
  84. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  85. time.Sleep(time.Millisecond)
  86. }))
  87. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  88. resp := httptest.NewRecorder()
  89. handler.ServeHTTP(resp, req)
  90. assert.Equal(t, http.StatusOK, resp.Code)
  91. }
  92. func TestWithinTimeoutBadCode(t *testing.T) {
  93. timeoutHandler := TimeoutHandler(time.Second)
  94. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  95. w.WriteHeader(http.StatusInternalServerError)
  96. }))
  97. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  98. resp := httptest.NewRecorder()
  99. handler.ServeHTTP(resp, req)
  100. assert.Equal(t, http.StatusInternalServerError, resp.Code)
  101. }
  102. func TestWithTimeoutTimedout(t *testing.T) {
  103. timeoutHandler := TimeoutHandler(time.Millisecond)
  104. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  105. time.Sleep(time.Millisecond * 10)
  106. _, err := w.Write([]byte(`foo`))
  107. if err != nil {
  108. w.WriteHeader(http.StatusInternalServerError)
  109. return
  110. }
  111. w.WriteHeader(http.StatusOK)
  112. }))
  113. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  114. resp := httptest.NewRecorder()
  115. handler.ServeHTTP(resp, req)
  116. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  117. }
  118. func TestWithoutTimeout(t *testing.T) {
  119. timeoutHandler := TimeoutHandler(0)
  120. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  121. time.Sleep(100 * time.Millisecond)
  122. }))
  123. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  124. resp := httptest.NewRecorder()
  125. handler.ServeHTTP(resp, req)
  126. assert.Equal(t, http.StatusOK, resp.Code)
  127. }
  128. func TestTimeoutPanic(t *testing.T) {
  129. timeoutHandler := TimeoutHandler(time.Minute)
  130. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  131. panic("foo")
  132. }))
  133. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  134. resp := httptest.NewRecorder()
  135. assert.Panics(t, func() {
  136. handler.ServeHTTP(resp, req)
  137. })
  138. }
  139. func TestTimeoutWebsocket(t *testing.T) {
  140. timeoutHandler := TimeoutHandler(time.Millisecond)
  141. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  142. time.Sleep(time.Millisecond * 10)
  143. }))
  144. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  145. req.Header.Set(headerUpgrade, valueWebsocket)
  146. resp := httptest.NewRecorder()
  147. handler.ServeHTTP(resp, req)
  148. assert.Equal(t, http.StatusOK, resp.Code)
  149. }
  150. func TestTimeoutWroteHeaderTwice(t *testing.T) {
  151. timeoutHandler := TimeoutHandler(time.Minute)
  152. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  153. _, err := w.Write([]byte(`hello`))
  154. if err != nil {
  155. w.WriteHeader(http.StatusInternalServerError)
  156. return
  157. }
  158. w.Header().Set("foo", "bar")
  159. w.WriteHeader(http.StatusOK)
  160. }))
  161. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  162. resp := httptest.NewRecorder()
  163. handler.ServeHTTP(resp, req)
  164. assert.Equal(t, http.StatusOK, resp.Code)
  165. }
  166. func TestTimeoutWriteBadCode(t *testing.T) {
  167. timeoutHandler := TimeoutHandler(time.Minute)
  168. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  169. w.WriteHeader(1000)
  170. }))
  171. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  172. resp := httptest.NewRecorder()
  173. assert.Panics(t, func() {
  174. handler.ServeHTTP(resp, req)
  175. })
  176. }
  177. func TestTimeoutClientClosed(t *testing.T) {
  178. timeoutHandler := TimeoutHandler(time.Minute)
  179. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  180. w.WriteHeader(http.StatusServiceUnavailable)
  181. }))
  182. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  183. ctx, cancel := context.WithCancel(context.Background())
  184. req = req.WithContext(ctx)
  185. cancel()
  186. resp := httptest.NewRecorder()
  187. handler.ServeHTTP(resp, req)
  188. assert.Equal(t, statusClientClosedRequest, resp.Code)
  189. }
  190. func TestTimeoutHijack(t *testing.T) {
  191. resp := httptest.NewRecorder()
  192. writer := &timeoutWriter{
  193. w: response.NewWithCodeResponseWriter(resp),
  194. }
  195. assert.NotPanics(t, func() {
  196. _, _, _ = writer.Hijack()
  197. })
  198. writer = &timeoutWriter{
  199. w: response.NewWithCodeResponseWriter(mockedHijackable{resp}),
  200. }
  201. assert.NotPanics(t, func() {
  202. _, _, _ = writer.Hijack()
  203. })
  204. }
  205. func TestTimeoutFlush(t *testing.T) {
  206. timeoutHandler := TimeoutHandler(time.Minute)
  207. handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  208. flusher, ok := w.(http.Flusher)
  209. if !ok {
  210. http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
  211. return
  212. }
  213. flusher.Flush()
  214. }))
  215. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  216. resp := httptest.NewRecorder()
  217. handler.ServeHTTP(resp, req)
  218. assert.Equal(t, http.StatusOK, resp.Code)
  219. }
  220. func TestTimeoutPusher(t *testing.T) {
  221. handler := &timeoutWriter{
  222. w: mockedPusher{},
  223. }
  224. assert.Panics(t, func() {
  225. _ = handler.Push("any", nil)
  226. })
  227. handler = &timeoutWriter{
  228. w: httptest.NewRecorder(),
  229. }
  230. assert.Equal(t, http.ErrNotSupported, handler.Push("any", nil))
  231. }
  232. func TestTimeoutWriter_Hijack(t *testing.T) {
  233. writer := &timeoutWriter{
  234. w: httptest.NewRecorder(),
  235. h: make(http.Header),
  236. req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
  237. }
  238. _, _, err := writer.Hijack()
  239. assert.Error(t, err)
  240. }
  241. func TestTimeoutWroteTwice(t *testing.T) {
  242. c := logtest.NewCollector(t)
  243. writer := &timeoutWriter{
  244. w: response.NewWithCodeResponseWriter(httptest.NewRecorder()),
  245. h: make(http.Header),
  246. req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
  247. }
  248. writer.writeHeaderLocked(http.StatusOK)
  249. writer.writeHeaderLocked(http.StatusOK)
  250. assert.Contains(t, c.String(), "superfluous response.WriteHeader call")
  251. }
  252. type mockedPusher struct{}
  253. func (m mockedPusher) Header() http.Header {
  254. panic("implement me")
  255. }
  256. func (m mockedPusher) Write(_ []byte) (int, error) {
  257. panic("implement me")
  258. }
  259. func (m mockedPusher) WriteHeader(_ int) {
  260. panic("implement me")
  261. }
  262. func (m mockedPusher) Push(_ string, _ *http.PushOptions) error {
  263. panic("implement me")
  264. }
  265. type mockedResponseWriter struct {
  266. http.ResponseWriter
  267. }
  268. func (m mockedResponseWriter) Header() http.Header {
  269. return m.ResponseWriter.Header()
  270. }
  271. func (m mockedResponseWriter) Write(bytes []byte) (int, error) {
  272. return m.ResponseWriter.Write(bytes)
  273. }
  274. func (m mockedResponseWriter) WriteHeader(statusCode int) {
  275. m.ResponseWriter.WriteHeader(statusCode)
  276. }