timeouthandler.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. package handler
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "path"
  12. "runtime"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/zeromicro/go-zero/rest/httpx"
  17. "github.com/zeromicro/go-zero/rest/internal"
  18. )
  // statusClientClosedRequest is nginx's non-standard 499 status, used when the
  // request context is canceled (e.g. the client went away) before the
  // response was produced.
  const statusClientClosedRequest = 499

  // reason is the plain-text body sent for timed-out or canceled requests.
  const reason = "Request Timeout"

  // Header name and token used to recognize WebSocket upgrade requests.
  const (
  	headerUpgrade  = "Upgrade"
  	valueWebsocket = "websocket"
  )
  25. // TimeoutHandler returns the handler with given timeout.
  26. // If client closed request, code 499 will be logged.
  27. // Notice: even if canceled in server side, 499 will be logged as well.
  28. func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
  29. return func(next http.Handler) http.Handler {
  30. if duration > 0 {
  31. return &timeoutHandler{
  32. handler: next,
  33. dt: duration,
  34. }
  35. }
  36. return next
  37. }
  38. }
  // timeoutHandler enforces a per-request deadline on the wrapped handler.
  //
  // It exists instead of net/http.TimeoutHandler because the standard library
  // reports a request closed by the client as http.StatusServiceUnavailable,
  // whereas this handler tells the two apart and reports a client-closed
  // request with nginx's code 499.
  type timeoutHandler struct {
  	// handler is the wrapped handler whose execution is time-limited.
  	handler http.Handler
  	// dt is the maximum duration allowed for a single request.
  	dt time.Duration
  }
  47. func (h *timeoutHandler) errorBody() string {
  48. return reason
  49. }
  50. func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  51. if r.Header.Get(headerUpgrade) == valueWebsocket {
  52. h.handler.ServeHTTP(w, r)
  53. return
  54. }
  55. ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
  56. defer cancelCtx()
  57. r = r.WithContext(ctx)
  58. done := make(chan struct{})
  59. tw := &timeoutWriter{
  60. w: w,
  61. h: make(http.Header),
  62. req: r,
  63. }
  64. panicChan := make(chan any, 1)
  65. go func() {
  66. defer func() {
  67. if p := recover(); p != nil {
  68. panicChan <- p
  69. }
  70. }()
  71. h.handler.ServeHTTP(tw, r)
  72. close(done)
  73. }()
  74. select {
  75. case p := <-panicChan:
  76. panic(p)
  77. case <-done:
  78. tw.mu.Lock()
  79. defer tw.mu.Unlock()
  80. dst := w.Header()
  81. for k, vv := range tw.h {
  82. dst[k] = vv
  83. }
  84. if !tw.wroteHeader {
  85. tw.code = http.StatusOK
  86. }
  87. w.WriteHeader(tw.code)
  88. w.Write(tw.wbuf.Bytes())
  89. case <-ctx.Done():
  90. tw.mu.Lock()
  91. defer tw.mu.Unlock()
  92. // there isn't any user-defined middleware before TimoutHandler,
  93. // so we can guarantee that cancelation in biz related code won't come here.
  94. httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
  95. if errors.Is(err, context.Canceled) {
  96. w.WriteHeader(statusClientClosedRequest)
  97. } else {
  98. w.WriteHeader(http.StatusServiceUnavailable)
  99. }
  100. io.WriteString(w, h.errorBody())
  101. })
  102. tw.timedOut = true
  103. }
  104. }
  // timeoutWriter is the http.ResponseWriter handed to the wrapped handler by
  // timeoutHandler.ServeHTTP. It buffers headers, status code and body so that
  // ServeHTTP can copy them to the real writer if the handler finishes in time,
  // or leave them unused if the deadline fires first.
  type timeoutWriter struct {
  	// w is the real ResponseWriter for the connection.
  	w http.ResponseWriter
  	// h collects the headers set by the handler.
  	h http.Header
  	// wbuf accumulates the response body.
  	wbuf bytes.Buffer
  	// req is the request being served; used when logging misuse.
  	req *http.Request

  	// mu serializes the handler goroutine's writes with ServeHTTP's final
  	// copy or timeout response.
  	mu sync.Mutex
  	// timedOut is set once the timeout response has been written.
  	timedOut bool
  	// wroteHeader reports whether a status code has been recorded.
  	wroteHeader bool
  	// code is the recorded status code.
  	code int
  }
  115. var _ http.Pusher = (*timeoutWriter)(nil)
  116. func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  117. if hijacked, ok := tw.w.(http.Hijacker); ok {
  118. return hijacked.Hijack()
  119. }
  120. return nil, nil, errors.New("server doesn't support hijacking")
  121. }
  122. // Header returns the underline temporary http.Header.
  123. func (tw *timeoutWriter) Header() http.Header { return tw.h }
  124. // Push implements the Pusher interface.
  125. func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
  126. if pusher, ok := tw.w.(http.Pusher); ok {
  127. return pusher.Push(target, opts)
  128. }
  129. return http.ErrNotSupported
  130. }
  131. // Write writes the data to the connection as part of an HTTP reply.
  132. // Timeout and multiple header written are guarded.
  133. func (tw *timeoutWriter) Write(p []byte) (int, error) {
  134. tw.mu.Lock()
  135. defer tw.mu.Unlock()
  136. if tw.timedOut {
  137. return 0, http.ErrHandlerTimeout
  138. }
  139. if !tw.wroteHeader {
  140. tw.writeHeaderLocked(http.StatusOK)
  141. }
  142. return tw.wbuf.Write(p)
  143. }
  144. func (tw *timeoutWriter) writeHeaderLocked(code int) {
  145. checkWriteHeaderCode(code)
  146. switch {
  147. case tw.timedOut:
  148. return
  149. case tw.wroteHeader:
  150. if tw.req != nil {
  151. caller := relevantCaller()
  152. internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
  153. caller.Function, path.Base(caller.File), caller.Line)
  154. }
  155. default:
  156. tw.wroteHeader = true
  157. tw.code = code
  158. }
  159. }
  160. func (tw *timeoutWriter) WriteHeader(code int) {
  161. tw.mu.Lock()
  162. defer tw.mu.Unlock()
  163. tw.writeHeaderLocked(code)
  164. }
  // checkWriteHeaderCode panics when code is outside 100 through 599 inclusive.
  // Valid codes return without effect.
  func checkWriteHeaderCode(code int) {
  	if code >= 100 && code <= 599 {
  		return
  	}

  	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
  }
  // relevantCaller returns the first stack frame that belongs neither to this
  // package nor to net/http. That frame is the code that actually triggered the
  // call, such as a handler issuing a superfluous WriteHeader. It exists only to
  // make error messages more helpful.
  //
  // The helper was adapted from net/http, where skipping "net/http." frames is
  // enough because the helper itself lives there. This copy lives in another
  // package, so frames of its own package (relevantCaller, writeHeaderLocked,
  // WriteHeader/Write, ...) must be skipped too. Otherwise the very first frame,
  // relevantCaller itself, would always be reported.
  //
  // If every captured frame is skipped, the last captured frame is returned.
  func relevantCaller() runtime.Frame {
  	pc := make([]uintptr, 16)
  	n := runtime.Callers(1, pc)
  	frames := runtime.CallersFrames(pc[:n])

  	// The first frame is relevantCaller itself; derive this package's symbol
  	// prefix (e.g. "github.com/zeromicro/go-zero/rest/handler.") from it.
  	self, more := frames.Next()
  	ownPkg := funcNamePackagePrefix(self.Function)

  	last := self
  	for more {
  		var frame runtime.Frame
  		frame, more = frames.Next()
  		last = frame
  		if !strings.HasPrefix(frame.Function, ownPkg) &&
  			!strings.HasPrefix(frame.Function, "net/http.") {
  			return frame
  		}
  	}

  	return last
  }

  // funcNamePackagePrefix returns the package part of a fully qualified function
  // name including the trailing dot, e.g. "a/b/pkg.(*T).M" -> "a/b/pkg.". The
  // package path ends at the first dot after the last slash. It is applied only
  // to relevantCaller's own, non-generic name, so brackets containing slashes do
  // not occur. If no dot is found, the name is returned unchanged.
  func funcNamePackagePrefix(funcName string) string {
  	start := strings.LastIndex(funcName, "/") + 1
  	dot := strings.Index(funcName[start:], ".")
  	if dot < 0 {
  		return funcName
  	}

  	return funcName[:start+dot+1]
  }