Przeglądaj źródła

chore: avoid nested WithCodeResponseWriter (#3406)

Kevin Wan 1 rok temu
rodzic
commit
13cdbdc98b

+ 1 - 1
rest/handler/breakerhandler.go

@@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
 				return
 			}
 
-			cw := &response.WithCodeResponseWriter{Writer: w}
+			cw := response.NewWithCodeResponseWriter(w)
 			defer func() {
 				if cw.Code < http.StatusInternalServerError {
 					promise.Accept()

+ 6 - 10
rest/handler/loghandler.go

@@ -36,14 +36,11 @@ func LogHandler(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		timer := utils.NewElapsedTimer()
 		logs := new(internal.LogCollector)
-		lrw := response.WithCodeResponseWriter{
-			Writer: w,
-			Code:   http.StatusOK,
-		}
+		lrw := response.NewWithCodeResponseWriter(w)
 
 		var dup io.ReadCloser
 		r.Body, dup = iox.DupReadCloser(r.Body)
-		next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
+		next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
 		r.Body = dup
 		logBrief(r, lrw.Code, timer, logs)
 	})
@@ -54,7 +51,8 @@ type detailLoggedResponseWriter struct {
 	buf    *bytes.Buffer
 }
 
-func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter {
+func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter,
+	buf *bytes.Buffer) *detailLoggedResponseWriter {
 	return &detailLoggedResponseWriter{
 		writer: writer,
 		buf:    buf,
@@ -93,10 +91,8 @@ func DetailedLogHandler(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		timer := utils.NewElapsedTimer()
 		var buf bytes.Buffer
-		lrw := newDetailLoggedResponseWriter(&response.WithCodeResponseWriter{
-			Writer: w,
-			Code:   http.StatusOK,
-		}, &buf)
+		rw := response.NewWithCodeResponseWriter(w)
+		lrw := newDetailLoggedResponseWriter(rw, &buf)
 
 		var dup io.ReadCloser
 		r.Body, dup = iox.DupReadCloser(r.Body)

+ 31 - 6
rest/handler/loghandler_test.go

@@ -2,6 +2,7 @@ package handler
 
 import (
 	"bytes"
+	"errors"
 	"io"
 	"net/http"
 	"net/http/httptest"
@@ -88,18 +89,23 @@ func TestLogHandlerSlow(t *testing.T) {
 func TestDetailedLogHandler_Hijack(t *testing.T) {
 	resp := httptest.NewRecorder()
 	writer := &detailLoggedResponseWriter{
-		writer: &response.WithCodeResponseWriter{
-			Writer: resp,
-		},
+		writer: response.NewWithCodeResponseWriter(resp),
 	}
 	assert.NotPanics(t, func() {
 		_, _, _ = writer.Hijack()
 	})
 
 	writer = &detailLoggedResponseWriter{
-		writer: &response.WithCodeResponseWriter{
-			Writer: mockedHijackable{resp},
-		},
+		writer: response.NewWithCodeResponseWriter(resp),
+	}
+	assert.NotPanics(t, func() {
+		_, _, _ = writer.Hijack()
+	})
+
+	writer = &detailLoggedResponseWriter{
+		writer: response.NewWithCodeResponseWriter(mockedHijackable{
+			ResponseRecorder: resp,
+		}),
 	}
 	assert.NotPanics(t, func() {
 		_, _, _ = writer.Hijack()
@@ -133,6 +139,13 @@ func TestWrapStatusCodeWithColor(t *testing.T) {
 	assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
 }
 
+func TestDumpRequest(t *testing.T) {
+	const errMsg = "error"
+	r := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
+	r.Body = mockedReadCloser{errMsg: errMsg}
+	assert.Equal(t, errMsg, dumpRequest(r))
+}
+
 func BenchmarkLogHandler(b *testing.B) {
 	b.ReportAllocs()
 
@@ -146,3 +159,15 @@ func BenchmarkLogHandler(b *testing.B) {
 		handler.ServeHTTP(resp, req)
 	}
 }
+
+type mockedReadCloser struct {
+	errMsg string
+}
+
+func (m mockedReadCloser) Read(p []byte) (n int, err error) {
+	return 0, errors.New(m.errMsg)
+}
+
+func (m mockedReadCloser) Close() error {
+	return nil
+}

+ 1 - 1
rest/handler/prometheushandler.go

@@ -35,7 +35,7 @@ func PrometheusHandler(path, method string) func(http.Handler) http.Handler {
 	return func(next http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			startTime := timex.Now()
-			cw := &response.WithCodeResponseWriter{Writer: w}
+			cw := response.NewWithCodeResponseWriter(w)
 			defer func() {
 				metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method)
 				metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method)

+ 1 - 1
rest/handler/sheddinghandler.go

@@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
 				return
 			}
 
-			cw := &response.WithCodeResponseWriter{Writer: w}
+			cw := response.NewWithCodeResponseWriter(w)
 			defer func() {
 				if cw.Code == http.StatusServiceUnavailable {
 					promise.Fail()

+ 9 - 6
rest/handler/timeouthandler.go

@@ -67,9 +67,10 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	r = r.WithContext(ctx)
 	done := make(chan struct{})
 	tw := &timeoutWriter{
-		w:   w,
-		h:   make(http.Header),
-		req: r,
+		w:    w,
+		h:    make(http.Header),
+		req:  r,
+		code: http.StatusOK,
 	}
 	panicChan := make(chan any, 1)
 	go func() {
@@ -91,10 +92,12 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		for k, vv := range tw.h {
 			dst[k] = vv
 		}
-		if !tw.wroteHeader {
-			tw.code = http.StatusOK
+
+		// We don't need to write header 200, because it's written by default.
+		// If we write it again, it will cause a warning: `http: superfluous response.WriteHeader call`.
+		if tw.code != http.StatusOK {
+			w.WriteHeader(tw.code)
 		}
-		w.WriteHeader(tw.code)
 		w.Write(tw.wbuf.Bytes())
 	case <-ctx.Done():
 		tw.mu.Lock()

+ 15 - 9
rest/handler/timeouthandler_test.go

@@ -100,6 +100,18 @@ func TestWithinTimeout(t *testing.T) {
 	assert.Equal(t, http.StatusOK, resp.Code)
 }
 
+func TestWithinTimeoutBadCode(t *testing.T) {
+	timeoutHandler := TimeoutHandler(time.Second)
+	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusInternalServerError, resp.Code)
+}
+
 func TestWithTimeoutTimedout(t *testing.T) {
 	timeoutHandler := TimeoutHandler(time.Millisecond)
 	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -208,9 +220,7 @@ func TestTimeoutHijack(t *testing.T) {
 	resp := httptest.NewRecorder()
 
 	writer := &timeoutWriter{
-		w: &response.WithCodeResponseWriter{
-			Writer: resp,
-		},
+		w: response.NewWithCodeResponseWriter(resp),
 	}
 
 	assert.NotPanics(t, func() {
@@ -218,9 +228,7 @@ func TestTimeoutHijack(t *testing.T) {
 	})
 
 	writer = &timeoutWriter{
-		w: &response.WithCodeResponseWriter{
-			Writer: mockedHijackable{resp},
-		},
+		w: response.NewWithCodeResponseWriter(mockedHijackable{resp}),
 	}
 
 	assert.NotPanics(t, func() {
@@ -274,9 +282,7 @@ func TestTimeoutWriter_Hijack(t *testing.T) {
 func TestTimeoutWroteTwice(t *testing.T) {
 	c := logtest.NewCollector(t)
 	writer := &timeoutWriter{
-		w: &response.WithCodeResponseWriter{
-			Writer: httptest.NewRecorder(),
-		},
+		w:   response.NewWithCodeResponseWriter(httptest.NewRecorder()),
 		h:   make(http.Header),
 		req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
 	}

+ 1 - 1
rest/handler/tracehandler.go

@@ -60,7 +60,7 @@ func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handl
 			// convenient for tracking error messages
 			propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
 
-			trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
+			trw := response.NewWithCodeResponseWriter(w)
 			next.ServeHTTP(trw, r.WithContext(spanCtx))
 
 			span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)

+ 14 - 0
rest/internal/response/withcoderesponsewriter.go

@@ -13,6 +13,20 @@ type WithCodeResponseWriter struct {
 	Code   int
 }
 
+// NewWithCodeResponseWriter returns a WithCodeResponseWriter.
+// If writer is already a WithCodeResponseWriter, it returns writer directly.
+func NewWithCodeResponseWriter(writer http.ResponseWriter) *WithCodeResponseWriter {
+	switch w := writer.(type) {
+	case *WithCodeResponseWriter:
+		return w
+	default:
+		return &WithCodeResponseWriter{
+			Writer: writer,
+			Code:   http.StatusOK,
+		}
+	}
+}
+
 // Flush flushes the response writer.
 func (w *WithCodeResponseWriter) Flush() {
 	if flusher, ok := w.Writer.(http.Flusher); ok {

+ 2 - 4
rest/internal/response/withcoderesponsewriter_test.go

@@ -11,7 +11,7 @@ import (
 func TestWithCodeResponseWriter(t *testing.T) {
 	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
 	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		cw := &WithCodeResponseWriter{Writer: w}
+		cw := NewWithCodeResponseWriter(w)
 
 		cw.Header().Set("X-Test", "test")
 		cw.WriteHeader(http.StatusServiceUnavailable)
@@ -34,9 +34,7 @@ func TestWithCodeResponseWriter(t *testing.T) {
 
 func TestWithCodeResponseWriter_Hijack(t *testing.T) {
 	resp := httptest.NewRecorder()
-	writer := &WithCodeResponseWriter{
-		Writer: resp,
-	}
+	writer := NewWithCodeResponseWriter(NewWithCodeResponseWriter(resp))
 	assert.NotPanics(t, func() {
 		writer.Hijack()
 	})