Browse Source

refactor(rest): keep rest log collector context key private (#3407)

cong 1 năm trước cách đây
mục cha
commit
61e562d0c7

+ 2 - 3
rest/handler/loghandler.go

@@ -3,7 +3,6 @@ package handler
 import (
 	"bufio"
 	"bytes"
-	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -44,7 +43,7 @@ func LogHandler(next http.Handler) http.Handler {
 
 		var dup io.ReadCloser
 		r.Body, dup = iox.DupReadCloser(r.Body)
-		next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
+		next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
 		r.Body = dup
 		logBrief(r, lrw.Code, timer, logs)
 	})
@@ -102,7 +101,7 @@ func DetailedLogHandler(next http.Handler) http.Handler {
 		var dup io.ReadCloser
 		r.Body, dup = iox.DupReadCloser(r.Body)
 		logs := new(internal.LogCollector)
-		next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
+		next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
 		r.Body = dup
 		logDetails(r, lrw, timer, logs)
 	})

+ 2 - 2
rest/handler/loghandler_test.go

@@ -22,7 +22,7 @@ func TestLogHandler(t *testing.T) {
 	for _, logHandler := range handlers {
 		req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
 		handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
+			internal.LogCollectorFromContext(r.Context()).Append("anything")
 			w.Header().Set("X-Test", "test")
 			w.WriteHeader(http.StatusServiceUnavailable)
 			_, err := w.Write([]byte("content"))
@@ -49,7 +49,7 @@ func TestLogHandlerVeryLong(t *testing.T) {
 
 	req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf)
 	handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
+		internal.LogCollectorFromContext(r.Context()).Append("anything")
 		_, _ = io.Copy(io.Discard, r.Body)
 		w.Header().Set("X-Test", "test")
 		w.WriteHeader(http.StatusServiceUnavailable)

+ 28 - 14
rest/internal/log.go

@@ -2,6 +2,7 @@ package internal
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"net/http"
 	"sync"
@@ -10,13 +11,32 @@ import (
 	"github.com/zeromicro/go-zero/rest/httpx"
 )
 
-// LogContext is a context key.
-var LogContext = contextKey("request_logs")
+// logContextKey is a context key.
+var logContextKey = contextKey("request_logs")
 
-// A LogCollector is used to collect logs.
-type LogCollector struct {
-	Messages []string
-	lock     sync.Mutex
+type (
+	// LogCollector is used to collect logs.
+	LogCollector struct {
+		Messages []string
+		lock     sync.Mutex
+	}
+
+	contextKey string
+)
+
+// WithLogCollector returns a new context with LogCollector.
+func WithLogCollector(ctx context.Context, lc *LogCollector) context.Context {
+	return context.WithValue(ctx, logContextKey, lc)
+}
+
+// LogCollectorFromContext returns LogCollector from ctx.
+func LogCollectorFromContext(ctx context.Context) *LogCollector {
+	val := ctx.Value(logContextKey)
+	if val == nil {
+		return nil
+	}
+
+	return val.(*LogCollector)
 }
 
 // Append appends msg into log context.
@@ -73,9 +93,9 @@ func Infof(r *http.Request, format string, v ...any) {
 }
 
 func appendLog(r *http.Request, message string) {
-	logs := r.Context().Value(LogContext)
+	logs := LogCollectorFromContext(r.Context())
 	if logs != nil {
-		logs.(*LogCollector).Append(message)
+		logs.Append(message)
 	}
 }
 
@@ -90,9 +110,3 @@ func formatf(r *http.Request, format string, v ...any) string {
 func formatWithReq(r *http.Request, v string) string {
 	return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
 }
-
-type contextKey string
-
-func (c contextKey) String() string {
-	return "rest/internal context key " + string(c)
-}

+ 7 - 4
rest/internal/log_test.go

@@ -14,7 +14,7 @@ import (
 func TestInfo(t *testing.T) {
 	collector := new(LogCollector)
 	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
-	req = req.WithContext(context.WithValue(req.Context(), LogContext, collector))
+	req = req.WithContext(WithLogCollector(req.Context(), collector))
 	Info(req, "first")
 	Infof(req, "second %s", "third")
 	val := collector.Flush()
@@ -35,7 +35,10 @@ func TestError(t *testing.T) {
 	assert.True(t, strings.Contains(val, "third"))
 }
 
-func TestContextKey_String(t *testing.T) {
-	val := contextKey("foo")
-	assert.True(t, strings.Contains(val.String(), "foo"))
+func TestLogCollectorContext(t *testing.T) {
+	ctx := context.Background()
+	assert.Nil(t, LogCollectorFromContext(ctx))
+	collector := new(LogCollector)
+	ctx = WithLogCollector(ctx, collector)
+	assert.Equal(t, collector, LogCollectorFromContext(ctx))
 }