timeouthandler.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. package handler
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "path"
  10. "runtime"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/zeromicro/go-zero/rest/httpx"
  15. "github.com/zeromicro/go-zero/rest/internal"
  16. )
  // statusClientClosedRequest is the non-standard status code 499, defined by
  // nginx, that this file writes when the request context was canceled.
  const statusClientClosedRequest = 499

  // reason is the text written as the response body when a request is cut short.
  const reason = "Request Timeout"
  21. // TimeoutHandler returns the handler with given timeout.
  22. // If client closed request, code 499 will be logged.
  23. // Notice: even if canceled in server side, 499 will be logged as well.
  24. func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
  25. return func(next http.Handler) http.Handler {
  26. if duration > 0 {
  27. return &timeoutHandler{
  28. handler: next,
  29. dt: duration,
  30. }
  31. }
  32. return next
  33. }
  34. }
  // timeoutHandler is an http.Handler that enforces a request timeout.
  // It is implemented here rather than taken from the stdlib because the stdlib
  // implementation treats a client-closed request as
  // http.StatusServiceUnavailable, whereas this package reports such requests
  // with code 499, which is defined by nginx.
  type timeoutHandler struct {
  	// handler is the wrapped handler that actually serves the request.
  	handler http.Handler
  	// dt is the timeout applied to each request's context.
  	dt time.Duration
  }
  43. func (h *timeoutHandler) errorBody() string {
  44. return reason
  45. }
  46. func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  47. ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
  48. defer cancelCtx()
  49. r = r.WithContext(ctx)
  50. done := make(chan struct{})
  51. tw := &timeoutWriter{
  52. w: w,
  53. h: make(http.Header),
  54. req: r,
  55. }
  56. panicChan := make(chan interface{}, 1)
  57. go func() {
  58. defer func() {
  59. if p := recover(); p != nil {
  60. panicChan <- p
  61. }
  62. }()
  63. h.handler.ServeHTTP(tw, r)
  64. close(done)
  65. }()
  66. select {
  67. case p := <-panicChan:
  68. panic(p)
  69. case <-done:
  70. tw.mu.Lock()
  71. defer tw.mu.Unlock()
  72. dst := w.Header()
  73. for k, vv := range tw.h {
  74. dst[k] = vv
  75. }
  76. if !tw.wroteHeader {
  77. tw.code = http.StatusOK
  78. }
  79. w.WriteHeader(tw.code)
  80. w.Write(tw.wbuf.Bytes())
  81. case <-ctx.Done():
  82. tw.mu.Lock()
  83. defer tw.mu.Unlock()
  84. // there isn't any user-defined middleware before TimoutHandler,
  85. // so we can guarantee that cancelation in biz related code won't come here.
  86. httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
  87. if errors.Is(err, context.Canceled) {
  88. w.WriteHeader(statusClientClosedRequest)
  89. } else {
  90. w.WriteHeader(http.StatusServiceUnavailable)
  91. }
  92. io.WriteString(w, h.errorBody())
  93. })
  94. tw.timedOut = true
  95. }
  96. }
  // timeoutWriter collects a handler's response (headers, status code and body)
  // in memory so that it can be copied to the real http.ResponseWriter later.
  type timeoutWriter struct {
  	// w is the underlying response writer.
  	w http.ResponseWriter
  	// h holds the headers set by the handler.
  	h http.Header
  	// wbuf accumulates the response body.
  	wbuf bytes.Buffer
  	// req is the request being served; it is passed along when logging.
  	req *http.Request

  	// mu is held by the methods in this file while they touch the fields below.
  	mu sync.Mutex
  	// timedOut is set once the request has been cut short.
  	timedOut bool
  	// wroteHeader records whether a status code has been chosen.
  	wroteHeader bool
  	// code is the status code chosen by the handler.
  	code int
  }
  107. var _ http.Pusher = (*timeoutWriter)(nil)
  108. // Header returns the underline temporary http.Header.
  109. func (tw *timeoutWriter) Header() http.Header { return tw.h }
  110. // Push implements the Pusher interface.
  111. func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
  112. if pusher, ok := tw.w.(http.Pusher); ok {
  113. return pusher.Push(target, opts)
  114. }
  115. return http.ErrNotSupported
  116. }
  117. // Write writes the data to the connection as part of an HTTP reply.
  118. // Timeout and multiple header written are guarded.
  119. func (tw *timeoutWriter) Write(p []byte) (int, error) {
  120. tw.mu.Lock()
  121. defer tw.mu.Unlock()
  122. if tw.timedOut {
  123. return 0, http.ErrHandlerTimeout
  124. }
  125. if !tw.wroteHeader {
  126. tw.writeHeaderLocked(http.StatusOK)
  127. }
  128. return tw.wbuf.Write(p)
  129. }
  130. func (tw *timeoutWriter) writeHeaderLocked(code int) {
  131. checkWriteHeaderCode(code)
  132. switch {
  133. case tw.timedOut:
  134. return
  135. case tw.wroteHeader:
  136. if tw.req != nil {
  137. caller := relevantCaller()
  138. internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
  139. caller.Function, path.Base(caller.File), caller.Line)
  140. }
  141. default:
  142. tw.wroteHeader = true
  143. tw.code = code
  144. }
  145. }
  146. func (tw *timeoutWriter) WriteHeader(code int) {
  147. tw.mu.Lock()
  148. defer tw.mu.Unlock()
  149. tw.writeHeaderLocked(code)
  150. }
  // checkWriteHeaderCode panics unless code lies in the range 100 to 599
  // inclusive.
  func checkWriteHeaderCode(code int) {
  	if code >= 100 && code <= 599 {
  		return
  	}

  	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
  }
  // relevantCaller searches the call stack for the first function that belongs
  // neither to net/http nor to the timeoutWriter methods, so that log messages
  // name the code that actually triggered the call.
  //
  // The walk starts at the caller of relevantCaller. Starting at relevantCaller
  // itself would always report this helper, since it is not part of net/http.
  //
  // If every inspected frame is filtered out, the last inspected frame is
  // returned. At most 16 frames are inspected.
  func relevantCaller() runtime.Frame {
  	pc := make([]uintptr, 16)
  	// Skip 0 is runtime.Callers, skip 1 is relevantCaller, skip 2 is its caller.
  	n := runtime.Callers(2, pc)
  	frames := runtime.CallersFrames(pc[:n])

  	var frame runtime.Frame
  	for {
  		var more bool
  		// Plain assignment keeps the outer frame up to date, so the fallback
  		// return below yields the last inspected frame, not a zero value.
  		frame, more = frames.Next()

  		inStdlib := strings.HasPrefix(frame.Function, "net/http.")
  		inWriter := strings.Contains(frame.Function, ".(*timeoutWriter).")
  		if !inStdlib && !inWriter {
  			return frame
  		}
  		if !more {
  			break
  		}
  	}

  	return frame
  }