浏览代码

fix: fixed the bug that old trace instances may be fetched

chenquan 2 年之前
父节点
当前提交
3bc40d9eaf
共有 6 个文件被更改,包括 77 次插入12 次删除
  1. 2 2
      core/stores/mon/trace.go
  2. 6 6
      core/stores/redis/hook.go
  3. 2 2
      core/stores/sqlx/trace.go
  4. 13 0
      core/trace/utils.go
  5. 51 0
      core/trace/utils_test.go
  6. 3 2
      rest/httpc/requests.go

+ 2 - 2
core/stores/mon/trace.go

@@ -5,7 +5,6 @@ import (
 
 	"github.com/zeromicro/go-zero/core/trace"
 	"go.mongodb.org/mongo-driver/mongo"
-	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/attribute"
 	"go.opentelemetry.io/otel/codes"
 	oteltrace "go.opentelemetry.io/otel/trace"
@@ -14,7 +13,8 @@ import (
 var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
 
 func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
-	tracer := otel.Tracer(trace.TraceName)
+	tracer := trace.TracerFromContext(ctx)
+
 	ctx, span := tracer.Start(ctx,
 		spanName,
 		oteltrace.WithSpanKind(oteltrace.SpanKindClient),

+ 6 - 6
core/stores/redis/hook.go

@@ -14,7 +14,7 @@ import (
 	"github.com/zeromicro/go-zero/core/mapping"
 	"github.com/zeromicro/go-zero/core/timex"
 	"github.com/zeromicro/go-zero/core/trace"
-	"go.opentelemetry.io/otel"
+
 	"go.opentelemetry.io/otel/attribute"
 	"go.opentelemetry.io/otel/codes"
 	oteltrace "go.opentelemetry.io/otel/trace"
@@ -25,15 +25,13 @@ const spanName = "redis"
 
 var (
 	startTimeKey          = contextKey("startTime")
-	durationHook          = hook{tracer: otel.Tracer(trace.TraceName)}
+	durationHook          = hook{}
 	redisCmdsAttributeKey = attribute.Key("redis.cmds")
 )
 
 type (
 	contextKey string
-	hook       struct {
-		tracer oteltrace.Tracer
-	}
+	hook       struct{}
 )
 
 func (h hook) BeforeProcess(ctx context.Context, cmd red.Cmder) (context.Context, error) {
@@ -155,7 +153,9 @@ func logDuration(ctx context.Context, cmds []red.Cmder, duration time.Duration)
 }
 
 func (h hook) startSpan(ctx context.Context, cmds ...red.Cmder) context.Context {
-	ctx, span := h.tracer.Start(ctx,
+	tracer := trace.TracerFromContext(ctx)
+
+	ctx, span := tracer.Start(ctx,
 		spanName,
 		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
 	)

+ 2 - 2
core/stores/sqlx/trace.go

@@ -5,7 +5,6 @@ import (
 	"database/sql"
 
 	"github.com/zeromicro/go-zero/core/trace"
-	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/attribute"
 	"go.opentelemetry.io/otel/codes"
 	oteltrace "go.opentelemetry.io/otel/trace"
@@ -14,7 +13,8 @@ import (
 var sqlAttributeKey = attribute.Key("sql.method")
 
 func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
-	tracer := otel.Tracer(trace.TraceName)
+	tracer := trace.TracerFromContext(ctx)
+
 	start, span := tracer.Start(ctx,
 		spanName,
 		oteltrace.WithSpanKind(oteltrace.SpanKindClient),

+ 13 - 0
core/trace/utils.go

@@ -6,8 +6,10 @@ import (
 	"strings"
 
 	ztrace "github.com/zeromicro/go-zero/internal/trace"
+	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/attribute"
 	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
+	"go.opentelemetry.io/otel/trace"
 	"google.golang.org/grpc/peer"
 )
 
@@ -75,3 +77,14 @@ func PeerAttr(addr string) []attribute.KeyValue {
 		semconv.NetPeerPortKey.String(port),
 	}
 }
+
+// TracerFromContext returns a tracer in ctx, otherwise returns a global tracer.
+func TracerFromContext(ctx context.Context) (tracer trace.Tracer) {
+	if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
+		tracer = span.TracerProvider().Tracer(TraceName)
+	} else {
+		tracer = otel.Tracer(TraceName)
+	}
+
+	return
+}

+ 51 - 0
core/trace/utils_test.go

@@ -6,8 +6,12 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/sdk/resource"
+	sdktrace "go.opentelemetry.io/otel/sdk/trace"
 	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
+	"go.opentelemetry.io/otel/trace"
 	"google.golang.org/grpc/peer"
 )
 
@@ -151,3 +155,50 @@ func TestPeerAttr(t *testing.T) {
 		})
 	}
 }
+
+func TestTracerFromContext(t *testing.T) {
+	traceFn := func(ctx context.Context, hasTraceId bool) {
+		spanContext := trace.SpanContextFromContext(ctx)
+		assert.Equal(t, spanContext.IsValid(), hasTraceId)
+		parentTraceId := spanContext.TraceID().String()
+
+		tracer := TracerFromContext(ctx)
+		_, span := tracer.Start(ctx, "b")
+		defer span.End()
+
+		spanContext = span.SpanContext()
+		assert.True(t, spanContext.IsValid())
+		if hasTraceId {
+			assert.Equal(t, parentTraceId, spanContext.TraceID().String())
+		}
+
+	}
+
+	t.Run("context", func(t *testing.T) {
+		opts := []sdktrace.TracerProviderOption{
+			// Set the sampling rate based on the parent span to 100%
+			sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
+			// Record information about this application in a Resource.
+			sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
+		}
+		tp = sdktrace.NewTracerProvider(opts...)
+		otel.SetTracerProvider(tp)
+		ctx, span := tp.Tracer(TraceName).Start(context.Background(), "a")
+
+		defer span.End()
+		traceFn(ctx, true)
+	})
+
+	t.Run("global", func(t *testing.T) {
+		opts := []sdktrace.TracerProviderOption{
+			// Set the sampling rate based on the parent span to 100%
+			sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
+			// Record information about this application in a Resource.
+			sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
+		}
+		tp = sdktrace.NewTracerProvider(opts...)
+		otel.SetTracerProvider(tp)
+
+		traceFn(context.Background(), false)
+	})
+}

+ 3 - 2
rest/httpc/requests.go

@@ -156,12 +156,13 @@ func fillPath(u *nurl.URL, val map[string]any) error {
 }
 
 func request(r *http.Request, cli client) (*http.Response, error) {
-	tracer := otel.Tracer(trace.TraceName)
+	ctx := r.Context()
+	tracer := trace.TracerFromContext(ctx)
 	propagator := otel.GetTextMapPropagator()
 
 	spanName := r.URL.Path
 	ctx, span := tracer.Start(
-		r.Context(),
+		ctx,
 		spanName,
 		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
 		oteltrace.WithAttributes(semconv.HTTPClientAttributesFromHTTPRequest(r)...),