Bladeren bron

chore: avoid superfluous WriteHeader call errors (#1275)

Kevin Wan 3 jaren geleden
bovenliggende
commit
c800f6f723
2 gewijzigde bestanden met toevoegingen van 97 en 6 verwijderingen
  1. 50 6
      rest/internal/cors/handlers.go
  2. 47 0
      rest/internal/cors/handlers_test.go

+ 50 - 6
rest/internal/cors/handlers.go

@@ -1,6 +1,11 @@
 package cors
 
-import "net/http"
+import (
+	"bufio"
+	"errors"
+	"net"
+	"net/http"
+)
 
 const (
 	allowOrigin      = "Access-Control-Allow-Origin"
@@ -25,15 +30,16 @@ const (
 // At most one origin can be specified, other origins are ignored if given, default to be *.
 func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		checkAndSetHeaders(w, r, origins)
+		gw := &guardedResponseWriter{w: w}
+		checkAndSetHeaders(gw, r, origins)
 		if fn != nil {
-			fn(w)
+			fn(gw)
 		}
 
-		if r.Method != http.MethodOptions {
-			w.WriteHeader(http.StatusNotFound)
+		if r.Method == http.MethodOptions {
+			gw.WriteHeader(http.StatusNoContent)
 		} else {
-			w.WriteHeader(http.StatusNoContent)
+			gw.WriteHeader(http.StatusNotFound)
 		}
 	})
 }
@@ -56,6 +62,44 @@ func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.Han
 	}
 }
 
+type guardedResponseWriter struct {
+	w           http.ResponseWriter
+	wroteHeader bool
+}
+
+func (w *guardedResponseWriter) Flush() {
+	if flusher, ok := w.w.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
+func (w *guardedResponseWriter) Header() http.Header {
+	return w.w.Header()
+}
+
+// Hijack implements the http.Hijacker interface.
+// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
+func (w *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	if hijacked, ok := w.w.(http.Hijacker); ok {
+		return hijacked.Hijack()
+	}
+
+	return nil, nil, errors.New("server doesn't support hijacking")
+}
+
+func (w *guardedResponseWriter) Write(bytes []byte) (int, error) {
+	return w.w.Write(bytes)
+}
+
+func (w *guardedResponseWriter) WriteHeader(code int) {
+	if w.wroteHeader {
+		return
+	}
+
+	w.w.WriteHeader(code)
+	w.wroteHeader = true
+}
+
 func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) {
 	setVaryHeaders(w, r)
 

+ 47 - 0
rest/internal/cors/handlers_test.go

@@ -1,6 +1,8 @@
 package cors
 
 import (
+	"bufio"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -129,3 +131,48 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 		}
 	}
 }
+
+func TestGuardedResponseWriter_Flush(t *testing.T) {
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	handler := NotAllowedHandler(func(w http.ResponseWriter) {
+		w.Header().Set("X-Test", "test")
+		w.WriteHeader(http.StatusServiceUnavailable)
+		_, err := w.Write([]byte("content"))
+		assert.Nil(t, err)
+
+		flusher, ok := w.(http.Flusher)
+		assert.True(t, ok)
+		flusher.Flush()
+	}, "foo.com")
+
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
+	assert.Equal(t, "test", resp.Header().Get("X-Test"))
+	assert.Equal(t, "content", resp.Body.String())
+}
+
+func TestGuardedResponseWriter_Hijack(t *testing.T) {
+	resp := httptest.NewRecorder()
+	writer := &guardedResponseWriter{
+		w: resp,
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+
+	writer = &guardedResponseWriter{
+		w: mockedHijackable{resp},
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+}
+
+type mockedHijackable struct {
+	*httptest.ResponseRecorder
+}
+
+func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	return nil, nil, nil
+}