Răsfoiți Sursa

Add traceId to the response headers (#919)

* Add traceId to the request headers

* Add test cases

* Update refactor code
chenquan 3 ani în urmă
părinte
comite
7c842f22d0

+ 1 - 1
core/trace/constants.go

@@ -1,6 +1,6 @@
 package trace
 
 const (
-	traceIdKey = "X-Trace-ID"
+	TraceIdKey = "X-Trace-ID"
 	spanIdKey  = "X-Span-ID"
 )

+ 8 - 8
core/trace/propagator_test.go

@@ -11,11 +11,11 @@ import (
 
 func TestHttpPropagator_Extract(t *testing.T) {
 	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
-	req.Header.Set(traceIdKey, "trace")
+	req.Header.Set(TraceIdKey, "trace")
 	req.Header.Set(spanIdKey, "span")
 	carrier, err := Extract(HttpFormat, req.Header)
 	assert.Nil(t, err)
-	assert.Equal(t, "trace", carrier.Get(traceIdKey))
+	assert.Equal(t, "trace", carrier.Get(TraceIdKey))
 	assert.Equal(t, "span", carrier.Get(spanIdKey))
 
 	_, err = Extract(HttpFormat, req)
@@ -24,11 +24,11 @@ func TestHttpPropagator_Extract(t *testing.T) {
 
 func TestHttpPropagator_Inject(t *testing.T) {
 	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
-	req.Header.Set(traceIdKey, "trace")
+	req.Header.Set(TraceIdKey, "trace")
 	req.Header.Set(spanIdKey, "span")
 	carrier, err := Inject(HttpFormat, req.Header)
 	assert.Nil(t, err)
-	assert.Equal(t, "trace", carrier.Get(traceIdKey))
+	assert.Equal(t, "trace", carrier.Get(TraceIdKey))
 	assert.Equal(t, "span", carrier.Get(spanIdKey))
 
 	_, err = Inject(HttpFormat, req)
@@ -37,12 +37,12 @@ func TestHttpPropagator_Inject(t *testing.T) {
 
 func TestGrpcPropagator_Extract(t *testing.T) {
 	md := metadata.New(map[string]string{
-		traceIdKey: "trace",
+		TraceIdKey: "trace",
 		spanIdKey:  "span",
 	})
 	carrier, err := Extract(GrpcFormat, md)
 	assert.Nil(t, err)
-	assert.Equal(t, "trace", carrier.Get(traceIdKey))
+	assert.Equal(t, "trace", carrier.Get(TraceIdKey))
 	assert.Equal(t, "span", carrier.Get(spanIdKey))
 
 	_, err = Extract(GrpcFormat, 1)
@@ -53,12 +53,12 @@ func TestGrpcPropagator_Extract(t *testing.T) {
 
 func TestGrpcPropagator_Inject(t *testing.T) {
 	md := metadata.New(map[string]string{
-		traceIdKey: "trace",
+		TraceIdKey: "trace",
 		spanIdKey:  "span",
 	})
 	carrier, err := Inject(GrpcFormat, md)
 	assert.Nil(t, err)
-	assert.Equal(t, "trace", carrier.Get(traceIdKey))
+	assert.Equal(t, "trace", carrier.Get(TraceIdKey))
 	assert.Equal(t, "span", carrier.Get(spanIdKey))
 
 	_, err = Inject(GrpcFormat, 1)

+ 1 - 1
core/trace/span.go

@@ -34,7 +34,7 @@ type Span struct {
 func newServerSpan(carrier Carrier, serviceName, operationName string) tracespec.Trace {
 	traceId := stringx.TakeWithPriority(func() string {
 		if carrier != nil {
-			return carrier.Get(traceIdKey)
+			return carrier.Get(TraceIdKey)
 		}
 		return ""
 	}, stringx.RandId)

+ 2 - 2
core/trace/span_test.go

@@ -57,7 +57,7 @@ func TestServerSpan(t *testing.T) {
 
 func TestServerSpan_WithCarrier(t *testing.T) {
 	md := metadata.New(map[string]string{
-		traceIdKey: "a",
+		TraceIdKey: "a",
 		spanIdKey:  "0.1",
 	})
 	ctx, span := StartServerSpan(context.Background(), grpcCarrier(md), "service", "operation")
@@ -99,7 +99,7 @@ func TestSpan_Follow(t *testing.T) {
 	for _, test := range tests {
 		t.Run(stringx.RandId(), func(t *testing.T) {
 			md := metadata.New(map[string]string{
-				traceIdKey: "a",
+				TraceIdKey: "a",
 				spanIdKey:  test.span,
 			})
 			ctx, span := StartServerSpan(context.Background(), grpcCarrier(md),

+ 1 - 1
core/trace/spancontext.go

@@ -14,6 +14,6 @@ func (sc spanContext) SpanId() string {
 }
 
 func (sc spanContext) Visit(fn func(key, val string) bool) {
-	fn(traceIdKey, sc.traceId)
+	fn(TraceIdKey, sc.traceId)
 	fn(spanIdKey, sc.spanId)
 }

+ 2 - 0
rest/handler/tracinghandler.go

@@ -21,6 +21,8 @@ func TracingHandler(next http.Handler) http.Handler {
 		defer span.Finish()
 		r = r.WithContext(ctx)
 
+		// Conveniently track error messages
+		w.Header().Set(trace.TraceIdKey, span.TraceId())
 		next.ServeHTTP(w, r)
 	})
 }

+ 9 - 2
rest/handler/tracinghandler_test.go

@@ -1,6 +1,8 @@
 package handler
 
 import (
+	"github.com/tal-tech/go-zero/core/stringx"
+	"github.com/tal-tech/go-zero/core/trace"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -11,14 +13,19 @@ import (
 
 func TestTracingHandler(t *testing.T) {
 	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
-	req.Header.Set("X-Trace-ID", "theid")
+
+	traceId := stringx.RandId()
+	req.Header.Set(trace.TraceIdKey, traceId)
+
 	handler := TracingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		span, ok := r.Context().Value(tracespec.TracingKey).(tracespec.Trace)
 		assert.True(t, ok)
-		assert.Equal(t, "theid", span.TraceId())
+		assert.Equal(t, traceId, span.TraceId())
 	}))
 
 	resp := httptest.NewRecorder()
 	handler.ServeHTTP(resp, req)
+
 	assert.Equal(t, http.StatusOK, resp.Code)
+	assert.Equal(t, traceId, resp.Header().Get(trace.TraceIdKey))
 }