瀏覽代碼

chore: optimize code (#1818)

Signed-off-by: chenquan <chenquan.dev@gmail.com>
chen quan 3 年之前
父節點
當前提交
22b157bb6c
共有 3 個文件被更改,包括 14 次插入17 次删除
  1. 6 2
      core/stores/mon/collection.go
  2. 3 11
      core/stores/redis/hook.go
  3. 5 4
      core/stores/redis/hook_test.go

+ 6 - 2
core/stores/mon/collection.go

@@ -16,7 +16,11 @@ import (
 	tracesdk "go.opentelemetry.io/otel/trace"
 )
 
-const defaultSlowThreshold = time.Millisecond * 500
+const (
+	defaultSlowThreshold = time.Millisecond * 500
+	// spanName is the span name of the mongo calls.
+	spanName = "mongo"
+)
 
 // ErrNotFound is an alias of mongo.ErrNoDocuments
 var ErrNotFound = mongo.ErrNoDocuments
@@ -482,5 +486,5 @@ func acceptable(err error) bool {
 
 func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
 	tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
-	return tracer.Start(ctx, "mongo")
+	return tracer.Start(ctx, spanName)
 }

+ 3 - 11
core/stores/redis/hook.go

@@ -19,7 +19,6 @@ const spanName = "redis"
 
 var (
 	startTimeKey = contextKey("startTime")
-	spanKey      = contextKey("span")
 	durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)}
 )
 
@@ -96,17 +95,10 @@ func logDuration(ctx context.Context, cmd red.Cmder, duration time.Duration) {
 }
 
 func (h hook) startSpan(ctx context.Context) context.Context {
-	ctx, span := h.tracer.Start(ctx, spanName)
-	return context.WithValue(ctx, spanKey, span)
+	ctx, _ = h.tracer.Start(ctx, spanName)
+	return ctx
 }
 
 func (h hook) endSpan(ctx context.Context) {
-	spanVal := ctx.Value(spanKey)
-	if spanVal == nil {
-		return
-	}
-
-	if span, ok := spanVal.(tracestd.Span); ok {
-		span.End()
-	}
+	tracestd.SpanFromContext(ctx).End()
 }

+ 5 - 4
core/stores/redis/hook_test.go

@@ -10,6 +10,7 @@ import (
 	red "github.com/go-redis/redis/v8"
 	"github.com/stretchr/testify/assert"
 	ztrace "github.com/zeromicro/go-zero/core/trace"
+	tracesdk "go.opentelemetry.io/otel/trace"
 )
 
 func TestHookProcessCase1(t *testing.T) {
@@ -32,7 +33,7 @@ func TestHookProcessCase1(t *testing.T) {
 
 	assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background())))
 	assert.False(t, strings.Contains(buf.String(), "slow"))
-	assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name())
+	assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
 }
 
 func TestHookProcessCase2(t *testing.T) {
@@ -52,7 +53,7 @@ func TestHookProcessCase2(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name())
+	assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
 
 	time.Sleep(slowThreshold.Load() + time.Millisecond)
 
@@ -93,7 +94,7 @@ func TestHookProcessPipelineCase1(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name())
+	assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
 
 	assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
 		red.NewCmd(context.Background()),
@@ -118,7 +119,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name())
+	assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
 
 	time.Sleep(slowThreshold.Load() + time.Millisecond)