timeouthandler.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. package handler
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "path"
  12. "runtime"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/wuntsong-org/go-zero-plus/rest/httpx"
  17. "github.com/wuntsong-org/go-zero-plus/rest/internal"
  18. )
  19. const (
  20. statusClientClosedRequest = 499
  21. reason = "Request Timeout"
  22. headerUpgrade = "Upgrade"
  23. valueWebsocket = "websocket"
  24. )
  25. // TimeoutHandler returns the handler with given timeout.
  26. // If client closed request, code 499 will be logged.
  27. // Notice: even if canceled in server side, 499 will be logged as well.
  28. func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
  29. return func(next http.Handler) http.Handler {
  30. if duration <= 0 {
  31. return next
  32. }
  33. return &timeoutHandler{
  34. handler: next,
  35. dt: duration,
  36. }
  37. }
  38. }
  39. // timeoutHandler is the handler that controls the request timeout.
  40. // Why we implement it on our own, because the stdlib implementation
  41. // treats the ClientClosedRequest as http.StatusServiceUnavailable.
  42. // And we write the codes in logs as code 499, which is defined by nginx.
  43. type timeoutHandler struct {
  44. handler http.Handler
  45. dt time.Duration
  46. }
  47. func (h *timeoutHandler) errorBody() string {
  48. return reason
  49. }
  50. func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  51. if r.Header.Get(headerUpgrade) == valueWebsocket {
  52. h.handler.ServeHTTP(w, r)
  53. return
  54. }
  55. ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
  56. defer cancelCtx()
  57. r = r.WithContext(ctx)
  58. done := make(chan struct{})
  59. tw := &timeoutWriter{
  60. w: w,
  61. h: make(http.Header),
  62. req: r,
  63. code: http.StatusOK,
  64. }
  65. panicChan := make(chan any, 1)
  66. go func() {
  67. defer func() {
  68. if p := recover(); p != nil {
  69. panicChan <- p
  70. }
  71. }()
  72. h.handler.ServeHTTP(tw, r)
  73. close(done)
  74. }()
  75. select {
  76. case p := <-panicChan:
  77. panic(p)
  78. case <-done:
  79. tw.mu.Lock()
  80. defer tw.mu.Unlock()
  81. dst := w.Header()
  82. for k, vv := range tw.h {
  83. dst[k] = vv
  84. }
  85. // We don't need to write header 200, because it's written by default.
  86. // If we write it again, it will cause a warning: `http: superfluous response.WriteHeader call`.
  87. if tw.code != http.StatusOK {
  88. w.WriteHeader(tw.code)
  89. }
  90. w.Write(tw.wbuf.Bytes())
  91. case <-ctx.Done():
  92. tw.mu.Lock()
  93. defer tw.mu.Unlock()
  94. // there isn't any user-defined middleware before TimoutHandler,
  95. // so we can guarantee that cancelation in biz related code won't come here.
  96. httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
  97. if errors.Is(err, context.Canceled) {
  98. w.WriteHeader(statusClientClosedRequest)
  99. } else {
  100. w.WriteHeader(http.StatusServiceUnavailable)
  101. }
  102. io.WriteString(w, h.errorBody())
  103. })
  104. tw.timedOut = true
  105. }
  106. }
  107. type timeoutWriter struct {
  108. w http.ResponseWriter
  109. h http.Header
  110. wbuf bytes.Buffer
  111. req *http.Request
  112. mu sync.Mutex
  113. timedOut bool
  114. wroteHeader bool
  115. code int
  116. }
  117. var _ http.Pusher = (*timeoutWriter)(nil)
  118. // Flush implements the Flusher interface.
  119. func (tw *timeoutWriter) Flush() {
  120. flusher, ok := tw.w.(http.Flusher)
  121. if !ok {
  122. return
  123. }
  124. header := tw.w.Header()
  125. for k, v := range tw.h {
  126. header[k] = v
  127. }
  128. tw.w.Write(tw.wbuf.Bytes())
  129. tw.wbuf.Reset()
  130. flusher.Flush()
  131. }
  132. // Header returns the underline temporary http.Header.
  133. func (tw *timeoutWriter) Header() http.Header {
  134. return tw.h
  135. }
  136. // Hijack implements the Hijacker interface.
  137. func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  138. if hijacked, ok := tw.w.(http.Hijacker); ok {
  139. return hijacked.Hijack()
  140. }
  141. return nil, nil, errors.New("server doesn't support hijacking")
  142. }
  143. // Push implements the Pusher interface.
  144. func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
  145. if pusher, ok := tw.w.(http.Pusher); ok {
  146. return pusher.Push(target, opts)
  147. }
  148. return http.ErrNotSupported
  149. }
  150. // Write writes the data to the connection as part of an HTTP reply.
  151. // Timeout and multiple header written are guarded.
  152. func (tw *timeoutWriter) Write(p []byte) (int, error) {
  153. tw.mu.Lock()
  154. defer tw.mu.Unlock()
  155. if tw.timedOut {
  156. return 0, http.ErrHandlerTimeout
  157. }
  158. if !tw.wroteHeader {
  159. tw.writeHeaderLocked(http.StatusOK)
  160. }
  161. return tw.wbuf.Write(p)
  162. }
  163. func (tw *timeoutWriter) writeHeaderLocked(code int) {
  164. checkWriteHeaderCode(code)
  165. switch {
  166. case tw.timedOut:
  167. return
  168. case tw.wroteHeader:
  169. if tw.req != nil {
  170. caller := relevantCaller()
  171. internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
  172. caller.Function, path.Base(caller.File), caller.Line)
  173. }
  174. default:
  175. tw.wroteHeader = true
  176. tw.code = code
  177. }
  178. }
  179. func (tw *timeoutWriter) WriteHeader(code int) {
  180. tw.mu.Lock()
  181. defer tw.mu.Unlock()
  182. if !tw.wroteHeader {
  183. tw.writeHeaderLocked(code)
  184. }
  185. }
  186. func checkWriteHeaderCode(code int) {
  187. if code < 100 || code > 599 {
  188. panic(fmt.Sprintf("invalid WriteHeader code %v", code))
  189. }
  190. }
  191. // relevantCaller searches the call stack for the first function outside of net/http.
  192. // The purpose of this function is to provide more helpful error messages.
  193. func relevantCaller() runtime.Frame {
  194. pc := make([]uintptr, 16)
  195. n := runtime.Callers(1, pc)
  196. frames := runtime.CallersFrames(pc[:n])
  197. var frame runtime.Frame
  198. for {
  199. frame, more := frames.Next()
  200. if !strings.HasPrefix(frame.Function, "net/http.") {
  201. return frame
  202. }
  203. if !more {
  204. break
  205. }
  206. }
  207. return frame
  208. }