timeouthandler.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package handler
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "path"
  10. "runtime"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/zeromicro/go-zero/rest/httpx"
  15. "github.com/zeromicro/go-zero/rest/internal"
  16. )
  // Response status and body used when a request does not complete in time.
const (
	// statusClientClosedRequest is the nginx-defined 499 code, written when the
	// request context is canceled rather than timing out.
	statusClientClosedRequest = 499
	// reason is the response body written for both the 499 and 503 outcomes.
	reason = "Request Timeout"
)

// Header name and value that identify WebSocket upgrade requests; such
// requests bypass the timeout entirely.
const (
	headerUpgrade  = "Upgrade"
	valueWebsocket = "websocket"
)
  23. // TimeoutHandler returns the handler with given timeout.
  24. // If client closed request, code 499 will be logged.
  25. // Notice: even if canceled in server side, 499 will be logged as well.
  26. func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
  27. return func(next http.Handler) http.Handler {
  28. if duration > 0 {
  29. return &timeoutHandler{
  30. handler: next,
  31. dt: duration,
  32. }
  33. }
  34. return next
  35. }
  36. }
  // timeoutHandler enforces a per-request deadline on the wrapped handler.
//
// It is implemented here rather than reusing http.TimeoutHandler because the
// standard library reports a request closed by the client as
// http.StatusServiceUnavailable. This implementation reports it as 499, the
// "client closed request" code defined by nginx.
type timeoutHandler struct {
	// handler does the actual work for each request.
	handler http.Handler
	// dt is how long handler may run before the request is failed.
	dt time.Duration
}
  45. func (h *timeoutHandler) errorBody() string {
  46. return reason
  47. }
  48. func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  49. if r.Header.Get(headerUpgrade) == valueWebsocket {
  50. h.handler.ServeHTTP(w, r)
  51. return
  52. }
  53. ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
  54. defer cancelCtx()
  55. r = r.WithContext(ctx)
  56. done := make(chan struct{})
  57. tw := &timeoutWriter{
  58. w: w,
  59. h: make(http.Header),
  60. req: r,
  61. }
  62. panicChan := make(chan interface{}, 1)
  63. go func() {
  64. defer func() {
  65. if p := recover(); p != nil {
  66. panicChan <- p
  67. }
  68. }()
  69. h.handler.ServeHTTP(tw, r)
  70. close(done)
  71. }()
  72. select {
  73. case p := <-panicChan:
  74. panic(p)
  75. case <-done:
  76. tw.mu.Lock()
  77. defer tw.mu.Unlock()
  78. dst := w.Header()
  79. for k, vv := range tw.h {
  80. dst[k] = vv
  81. }
  82. if !tw.wroteHeader {
  83. tw.code = http.StatusOK
  84. }
  85. w.WriteHeader(tw.code)
  86. w.Write(tw.wbuf.Bytes())
  87. case <-ctx.Done():
  88. tw.mu.Lock()
  89. defer tw.mu.Unlock()
  90. // there isn't any user-defined middleware before TimoutHandler,
  91. // so we can guarantee that cancelation in biz related code won't come here.
  92. httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
  93. if errors.Is(err, context.Canceled) {
  94. w.WriteHeader(statusClientClosedRequest)
  95. } else {
  96. w.WriteHeader(http.StatusServiceUnavailable)
  97. }
  98. io.WriteString(w, h.errorBody())
  99. })
  100. tw.timedOut = true
  101. }
  102. }
  103. type timeoutWriter struct {
  104. w http.ResponseWriter
  105. h http.Header
  106. wbuf bytes.Buffer
  107. req *http.Request
  108. mu sync.Mutex
  109. timedOut bool
  110. wroteHeader bool
  111. code int
  112. }
  113. var _ http.Pusher = (*timeoutWriter)(nil)
  114. // Header returns the underline temporary http.Header.
  115. func (tw *timeoutWriter) Header() http.Header { return tw.h }
  116. // Push implements the Pusher interface.
  117. func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
  118. if pusher, ok := tw.w.(http.Pusher); ok {
  119. return pusher.Push(target, opts)
  120. }
  121. return http.ErrNotSupported
  122. }
  123. // Write writes the data to the connection as part of an HTTP reply.
  124. // Timeout and multiple header written are guarded.
  125. func (tw *timeoutWriter) Write(p []byte) (int, error) {
  126. tw.mu.Lock()
  127. defer tw.mu.Unlock()
  128. if tw.timedOut {
  129. return 0, http.ErrHandlerTimeout
  130. }
  131. if !tw.wroteHeader {
  132. tw.writeHeaderLocked(http.StatusOK)
  133. }
  134. return tw.wbuf.Write(p)
  135. }
  136. func (tw *timeoutWriter) writeHeaderLocked(code int) {
  137. checkWriteHeaderCode(code)
  138. switch {
  139. case tw.timedOut:
  140. return
  141. case tw.wroteHeader:
  142. if tw.req != nil {
  143. caller := relevantCaller()
  144. internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
  145. caller.Function, path.Base(caller.File), caller.Line)
  146. }
  147. default:
  148. tw.wroteHeader = true
  149. tw.code = code
  150. }
  151. }
  152. func (tw *timeoutWriter) WriteHeader(code int) {
  153. tw.mu.Lock()
  154. defer tw.mu.Unlock()
  155. tw.writeHeaderLocked(code)
  156. }
  // checkWriteHeaderCode panics with an "invalid WriteHeader code" message
// unless code lies in the inclusive range [100, 599].
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 599 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
  // relevantCaller returns the first stack frame, above its caller chain, that is
// neither inside net/http nor one of the *timeoutWriter methods. The
// superfluous-WriteHeader log message therefore points at the code that
// actually made the call. If every inspected frame is skipped, the last one
// inspected is returned. At most 16 frames are examined.
func relevantCaller() runtime.Frame {
	pc := make([]uintptr, 16)
	// Skip runtime.Callers (0) and relevantCaller itself (1). The net/http
	// original gets this for free because its own frame matches the
	// "net/http." prefix. This copy lives in another package, so without the
	// explicit skip it would always report itself as the caller.
	n := runtime.Callers(2, pc)
	frames := runtime.CallersFrames(pc[:n])

	var last runtime.Frame
	for {
		frame, more := frames.Next()
		if !strings.HasPrefix(frame.Function, "net/http.") &&
			!strings.Contains(frame.Function, ".(*timeoutWriter).") {
			return frame
		}
		// Track the fallback explicitly. The original shadowed the outer
		// variable with :=, so it always fell back to a zero Frame.
		last = frame
		if !more {
			break
		}
	}
	return last
}