timeouthandler.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. package handler
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "path"
  12. "runtime"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/zeromicro/go-zero/rest/httpx"
  17. "github.com/zeromicro/go-zero/rest/internal"
  18. )
  19. const (
  20. statusClientClosedRequest = 499
  21. reason = "Request Timeout"
  22. headerUpgrade = "Upgrade"
  23. valueWebsocket = "websocket"
  24. )
  25. // TimeoutHandler returns the handler with given timeout.
  26. // If client closed request, code 499 will be logged.
  27. // Notice: even if canceled in server side, 499 will be logged as well.
  28. func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
  29. return func(next http.Handler) http.Handler {
  30. if duration <= 0 {
  31. return next
  32. }
  33. return &timeoutHandler{
  34. handler: next,
  35. dt: duration,
  36. }
  37. }
  38. }
  39. // timeoutHandler is the handler that controls the request timeout.
  40. // Why we implement it on our own, because the stdlib implementation
  41. // treats the ClientClosedRequest as http.StatusServiceUnavailable.
  42. // And we write the codes in logs as code 499, which is defined by nginx.
  43. type timeoutHandler struct {
  44. handler http.Handler
  45. dt time.Duration
  46. }
  47. func (h *timeoutHandler) errorBody() string {
  48. return reason
  49. }
  50. func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  51. if r.Header.Get(headerUpgrade) == valueWebsocket {
  52. h.handler.ServeHTTP(w, r)
  53. return
  54. }
  55. ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
  56. defer cancelCtx()
  57. r = r.WithContext(ctx)
  58. done := make(chan struct{})
  59. tw := &timeoutWriter{
  60. w: w,
  61. h: make(http.Header),
  62. req: r,
  63. }
  64. panicChan := make(chan any, 1)
  65. go func() {
  66. defer func() {
  67. if p := recover(); p != nil {
  68. panicChan <- p
  69. }
  70. }()
  71. h.handler.ServeHTTP(tw, r)
  72. close(done)
  73. }()
  74. select {
  75. case p := <-panicChan:
  76. panic(p)
  77. case <-done:
  78. tw.mu.Lock()
  79. defer tw.mu.Unlock()
  80. dst := w.Header()
  81. for k, vv := range tw.h {
  82. dst[k] = vv
  83. }
  84. if !tw.wroteHeader {
  85. tw.code = http.StatusOK
  86. }
  87. w.WriteHeader(tw.code)
  88. w.Write(tw.wbuf.Bytes())
  89. case <-ctx.Done():
  90. tw.mu.Lock()
  91. defer tw.mu.Unlock()
  92. // there isn't any user-defined middleware before TimoutHandler,
  93. // so we can guarantee that cancelation in biz related code won't come here.
  94. httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
  95. if errors.Is(err, context.Canceled) {
  96. w.WriteHeader(statusClientClosedRequest)
  97. } else {
  98. w.WriteHeader(http.StatusServiceUnavailable)
  99. }
  100. io.WriteString(w, h.errorBody())
  101. })
  102. tw.timedOut = true
  103. }
  104. }
  105. type timeoutWriter struct {
  106. w http.ResponseWriter
  107. h http.Header
  108. wbuf bytes.Buffer
  109. req *http.Request
  110. mu sync.Mutex
  111. timedOut bool
  112. wroteHeader bool
  113. code int
  114. }
// Compile-time check that *timeoutWriter implements http.Pusher.
var _ http.Pusher = (*timeoutWriter)(nil)
  116. func (tw *timeoutWriter) Flush() {
  117. if flusher, ok := tw.w.(http.Flusher); ok {
  118. flusher.Flush()
  119. }
  120. }
  121. func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  122. if hijacked, ok := tw.w.(http.Hijacker); ok {
  123. return hijacked.Hijack()
  124. }
  125. return nil, nil, errors.New("server doesn't support hijacking")
  126. }
  127. // Header returns the underline temporary http.Header.
  128. func (tw *timeoutWriter) Header() http.Header { return tw.h }
  129. // Push implements the Pusher interface.
  130. func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
  131. if pusher, ok := tw.w.(http.Pusher); ok {
  132. return pusher.Push(target, opts)
  133. }
  134. return http.ErrNotSupported
  135. }
  136. // Write writes the data to the connection as part of an HTTP reply.
  137. // Timeout and multiple header written are guarded.
  138. func (tw *timeoutWriter) Write(p []byte) (int, error) {
  139. tw.mu.Lock()
  140. defer tw.mu.Unlock()
  141. if tw.timedOut {
  142. return 0, http.ErrHandlerTimeout
  143. }
  144. if !tw.wroteHeader {
  145. tw.writeHeaderLocked(http.StatusOK)
  146. }
  147. return tw.wbuf.Write(p)
  148. }
  149. func (tw *timeoutWriter) writeHeaderLocked(code int) {
  150. checkWriteHeaderCode(code)
  151. switch {
  152. case tw.timedOut:
  153. return
  154. case tw.wroteHeader:
  155. if tw.req != nil {
  156. caller := relevantCaller()
  157. internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
  158. caller.Function, path.Base(caller.File), caller.Line)
  159. }
  160. default:
  161. tw.wroteHeader = true
  162. tw.code = code
  163. }
  164. }
  165. func (tw *timeoutWriter) WriteHeader(code int) {
  166. tw.mu.Lock()
  167. defer tw.mu.Unlock()
  168. if !tw.wroteHeader {
  169. tw.writeHeaderLocked(code)
  170. }
  171. }
  172. func checkWriteHeaderCode(code int) {
  173. if code < 100 || code > 599 {
  174. panic(fmt.Sprintf("invalid WriteHeader code %v", code))
  175. }
  176. }
  177. // relevantCaller searches the call stack for the first function outside of net/http.
  178. // The purpose of this function is to provide more helpful error messages.
  179. func relevantCaller() runtime.Frame {
  180. pc := make([]uintptr, 16)
  181. n := runtime.Callers(1, pc)
  182. frames := runtime.CallersFrames(pc[:n])
  183. var frame runtime.Frame
  184. for {
  185. frame, more := frames.Next()
  186. if !strings.HasPrefix(frame.Function, "net/http.") {
  187. return frame
  188. }
  189. if !more {
  190. break
  191. }
  192. }
  193. return frame
  194. }