Эх сурвалжийг харах

feat: treat client closed requests as code 499 (#1350)

* feat: treat client closed requests as code 499

* chore: add comments
Kevin Wan 3 жил өмнө
parent
commit
4ba2ff7cdd

+ 170 - 2
rest/handler/timeouthandler.go

@@ -1,19 +1,187 @@
 package handler
 
 import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
 	"net/http"
+	"path"
+	"runtime"
+	"strings"
+	"sync"
 	"time"
+
+	"github.com/tal-tech/go-zero/rest/internal"
 )
 
-const reason = "Request Timeout"
+const (
+	statusClientClosedRequest = 499
+	reason                    = "Request Timeout"
+)
 
 // TimeoutHandler returns the handler with given timeout.
+// If client closed request, code 499 will be logged.
+// Notice: even if canceled in server side, 499 will be logged as well.
 func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
 	return func(next http.Handler) http.Handler {
 		if duration > 0 {
-			return http.TimeoutHandler(next, duration, reason)
+			return &timeoutHandler{
+				handler: next,
+				dt:      duration,
+			}
 		}
 
 		return next
 	}
 }
+
+// timeoutHandler is the handler that controls the request timeout.
+// Why we implement it on our own, because the stdlib implementation
+// treats the ClientClosedRequest as http.StatusServiceUnavailable.
+// And we write the codes in logs as code 499, which is defined by nginx.
+type timeoutHandler struct {
+	handler http.Handler
+	dt      time.Duration
+}
+
+func (h *timeoutHandler) errorBody() string {
+	return reason
+}
+
+func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
+	defer cancelCtx()
+
+	r = r.WithContext(ctx)
+	done := make(chan struct{})
+	tw := &timeoutWriter{
+		w:   w,
+		h:   make(http.Header),
+		req: r,
+	}
+	panicChan := make(chan interface{}, 1)
+	go func() {
+		defer func() {
+			if p := recover(); p != nil {
+				panicChan <- p
+			}
+		}()
+		h.handler.ServeHTTP(tw, r)
+		close(done)
+	}()
+	select {
+	case p := <-panicChan:
+		panic(p)
+	case <-done:
+		tw.mu.Lock()
+		defer tw.mu.Unlock()
+		dst := w.Header()
+		for k, vv := range tw.h {
+			dst[k] = vv
+		}
+		if !tw.wroteHeader {
+			tw.code = http.StatusOK
+		}
+		w.WriteHeader(tw.code)
+		w.Write(tw.wbuf.Bytes())
+	case <-ctx.Done():
+		tw.mu.Lock()
+		defer tw.mu.Unlock()
+		if errors.Is(ctx.Err(), context.Canceled) {
+			w.WriteHeader(statusClientClosedRequest)
+		} else {
+			w.WriteHeader(http.StatusServiceUnavailable)
+		}
+		io.WriteString(w, h.errorBody())
+		tw.timedOut = true
+	}
+}
+
+type timeoutWriter struct {
+	w    http.ResponseWriter
+	h    http.Header
+	wbuf bytes.Buffer
+	req  *http.Request
+
+	mu          sync.Mutex
+	timedOut    bool
+	wroteHeader bool
+	code        int
+}
+
+var _ http.Pusher = (*timeoutWriter)(nil)
+
+// Push implements the Pusher interface.
+func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
+	if pusher, ok := tw.w.(http.Pusher); ok {
+		return pusher.Push(target, opts)
+	}
+	return http.ErrNotSupported
+}
+
+func (tw *timeoutWriter) Header() http.Header { return tw.h }
+
+func (tw *timeoutWriter) Write(p []byte) (int, error) {
+	tw.mu.Lock()
+	defer tw.mu.Unlock()
+
+	if tw.timedOut {
+		return 0, http.ErrHandlerTimeout
+	}
+
+	if !tw.wroteHeader {
+		tw.writeHeaderLocked(http.StatusOK)
+	}
+	return tw.wbuf.Write(p)
+}
+
+func (tw *timeoutWriter) writeHeaderLocked(code int) {
+	checkWriteHeaderCode(code)
+
+	switch {
+	case tw.timedOut:
+		return
+	case tw.wroteHeader:
+		if tw.req != nil {
+			caller := relevantCaller()
+			internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
+				caller.Function, path.Base(caller.File), caller.Line)
+		}
+	default:
+		tw.wroteHeader = true
+		tw.code = code
+	}
+}
+
+func (tw *timeoutWriter) WriteHeader(code int) {
+	tw.mu.Lock()
+	defer tw.mu.Unlock()
+	tw.writeHeaderLocked(code)
+}
+
+func checkWriteHeaderCode(code int) {
+	if code < 100 || code > 599 {
+		panic(fmt.Sprintf("invalid WriteHeader code %v", code))
+	}
+}
+
+// relevantCaller searches the call stack for the first function outside of net/http.
+// The purpose of this function is to provide more helpful error messages.
+func relevantCaller() runtime.Frame {
+	pc := make([]uintptr, 16)
+	n := runtime.Callers(1, pc)
+	frames := runtime.CallersFrames(pc[:n])
+	var frame runtime.Frame
+	for {
+		frame, more := frames.Next()
+		if !strings.HasPrefix(frame.Function, "net/http.") {
+			return frame
+		}
+		if !more {
+			break
+		}
+	}
+	return frame
+}

+ 103 - 0
rest/handler/timeouthandler_test.go

@@ -1,6 +1,7 @@
 package handler
 
 import (
+	"context"
 	"io/ioutil"
 	"log"
 	"net/http"
@@ -39,6 +40,20 @@ func TestWithinTimeout(t *testing.T) {
 	assert.Equal(t, http.StatusOK, resp.Code)
 }
 
+func TestWithTimeoutTimedout(t *testing.T) {
+	timeoutHandler := TimeoutHandler(time.Millisecond)
+	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		time.Sleep(time.Millisecond * 10)
+		w.Write([]byte(`foo`))
+		w.WriteHeader(http.StatusOK)
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
+}
+
 func TestWithoutTimeout(t *testing.T) {
 	timeoutHandler := TimeoutHandler(0)
 	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -50,3 +65,91 @@ func TestWithoutTimeout(t *testing.T) {
 	handler.ServeHTTP(resp, req)
 	assert.Equal(t, http.StatusOK, resp.Code)
 }
+
+func TestTimeoutPanic(t *testing.T) {
+	timeoutHandler := TimeoutHandler(time.Minute)
+	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		panic("foo")
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	resp := httptest.NewRecorder()
+	assert.Panics(t, func() {
+		handler.ServeHTTP(resp, req)
+	})
+}
+
+func TestTimeoutWroteHeaderTwice(t *testing.T) {
+	timeoutHandler := TimeoutHandler(time.Minute)
+	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(`hello`))
+		w.Header().Set("foo", "bar")
+		w.WriteHeader(http.StatusOK)
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusOK, resp.Code)
+}
+
+func TestTimeoutWriteBadCode(t *testing.T) {
+	timeoutHandler := TimeoutHandler(time.Minute)
+	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(1000)
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	resp := httptest.NewRecorder()
+	assert.Panics(t, func() {
+		handler.ServeHTTP(resp, req)
+	})
+}
+
+func TestTimeoutClientClosed(t *testing.T) {
+	timeoutHandler := TimeoutHandler(time.Minute)
+	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(1000)
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	ctx, cancel := context.WithCancel(context.Background())
+	req = req.WithContext(ctx)
+	cancel()
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, statusClientClosedRequest, resp.Code)
+}
+
+func TestTimeoutPusher(t *testing.T) {
+	handler := &timeoutWriter{
+		w: mockedPusher{},
+	}
+
+	assert.Panics(t, func() {
+		handler.Push("any", nil)
+	})
+
+	handler = &timeoutWriter{
+		w: httptest.NewRecorder(),
+	}
+	assert.Equal(t, http.ErrNotSupported, handler.Push("any", nil))
+}
+
+type mockedPusher struct{}
+
+func (m mockedPusher) Header() http.Header {
+	panic("implement me")
+}
+
+func (m mockedPusher) Write(bytes []byte) (int, error) {
+	panic("implement me")
+}
+
+func (m mockedPusher) WriteHeader(statusCode int) {
+	panic("implement me")
+}
+
+func (m mockedPusher) Push(target string, opts *http.PushOptions) error {
+	panic("implement me")
+}