Bladeren bron

fix(rest): fix issues#2628 (#2629)

chen quan 2 jaren geleden
bovenliggende
commit
97a8b3ade5
2 gewijzigde bestanden met toevoegingen van 11 en 11 verwijderingen
  1. 3 6
      rest/handler/tracinghandler.go
  2. 8 5
      rest/handler/tracinghandler_test.go

+ 3 - 6
rest/handler/tracinghandler.go

@@ -26,20 +26,17 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler {
 		tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
 
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			defer func() {
-				next.ServeHTTP(w, r)
-			}()
-
-			ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
 			spanName := path
 			if len(spanName) == 0 {
 				spanName = r.URL.Path
 			}
 
 			if _, ok := notTracingSpans.Load(spanName); ok {
+				next.ServeHTTP(w, r)
 				return
 			}
 
+			ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
 			spanCtx, span := tracer.Start(
 				ctx,
 				spanName,
@@ -51,7 +48,7 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler {
 
 			// convenient for tracking error messages
 			propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
-			r = r.WithContext(spanCtx)
+			next.ServeHTTP(w, r.WithContext(spanCtx))
 		})
 	}
 }

+ 8 - 5
rest/handler/tracinghandler_test.go

@@ -27,9 +27,9 @@ func TestOtelHandler(t *testing.T) {
 		t.Run(test, func(t *testing.T) {
 			h := chain.New(TracingHandler("foo", test)).Then(
 				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-					ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
-					spanCtx := trace.SpanContextFromContext(ctx)
-					assert.True(t, spanCtx.IsValid())
+					span := trace.SpanFromContext(r.Context())
+					assert.True(t, span.SpanContext().IsValid())
+					assert.True(t, span.IsRecording())
 				}))
 			ts := httptest.NewServer(h)
 			defer ts.Close()
@@ -52,7 +52,7 @@ func TestOtelHandler(t *testing.T) {
 	}
 }
 
-func TestDontTracingSpanName(t *testing.T) {
+func TestDontTracingSpan(t *testing.T) {
 	ztrace.StartAgent(ztrace.Config{
 		Name:     "go-zero-test",
 		Endpoint: "http://localhost:14268/api/traces",
@@ -66,12 +66,15 @@ func TestDontTracingSpanName(t *testing.T) {
 		t.Run(test, func(t *testing.T) {
 			h := chain.New(TracingHandler("foo", test)).Then(
 				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-					spanCtx := trace.SpanContextFromContext(r.Context())
+					span := trace.SpanFromContext(r.Context())
+					spanCtx := span.SpanContext()
 					if test == "bar" {
 						assert.False(t, spanCtx.IsValid())
+						assert.False(t, span.IsRecording())
 						return
 					}
 
+					assert.True(t, span.IsRecording())
 					assert.True(t, spanCtx.IsValid())
 				}))
 			ts := httptest.NewServer(h)