Browse Source

fix: ignore timeout on websocket (#1802)

Kevin Wan 3 years ago
parent
commit
92b450eb11
2 changed files with 20 additions and 0 deletions
  1. 7 0
      rest/handler/timeouthandler.go
  2. 13 0
      rest/handler/timeouthandler_test.go

+ 7 - 0
rest/handler/timeouthandler.go

@@ -20,6 +20,8 @@ import (
 const (
 const (
 	statusClientClosedRequest = 499
 	statusClientClosedRequest = 499
 	reason                    = "Request Timeout"
 	reason                    = "Request Timeout"
+	headerUpgrade             = "Upgrade"
+	valueWebsocket            = "websocket"
 )
 )
 
 
 // TimeoutHandler returns the handler with given timeout.
 // TimeoutHandler returns the handler with given timeout.
@@ -52,6 +54,11 @@ func (h *timeoutHandler) errorBody() string {
 }
 }
 
 
 func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	if r.Header.Get(headerUpgrade) == valueWebsocket {
+		h.handler.ServeHTTP(w, r)
+		return
+	}
+
 	ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
 	ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
 	defer cancelCtx()
 	defer cancelCtx()
 
 

+ 13 - 0
rest/handler/timeouthandler_test.go

@@ -79,6 +79,19 @@ func TestTimeoutPanic(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestTimeoutWebsocket(t *testing.T) {
+	timeoutHandler := TimeoutHandler(time.Millisecond)
+	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		time.Sleep(time.Millisecond * 10)
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	req.Header.Set(headerUpgrade, valueWebsocket)
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusOK, resp.Code)
+}
+
 func TestTimeoutWroteHeaderTwice(t *testing.T) {
 func TestTimeoutWroteHeaderTwice(t *testing.T) {
 	timeoutHandler := TimeoutHandler(time.Minute)
 	timeoutHandler := TimeoutHandler(time.Minute)
 	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {