浏览代码

chore: refactor (#1814)

Kevin Wan 3 年之前
父节点
当前提交
bc3c9484d1
共有 5 个文件被更改,包括 32 次插入20 次删除
  1. 4 3
      core/stores/mon/collection.go
  2. 13 6
      core/stores/mon/model.go
  3. 1 2
      core/stores/mon/model_test.go
  4. 10 7
      core/stores/redis/hook.go
  5. 4 2
      core/stores/sqlx/sqlconn.go

+ 4 - 3
core/stores/mon/collection.go

@@ -473,9 +473,10 @@ func (p keepablePromise) keep(err error) error {
 func acceptable(err error) bool {
 func acceptable(err error) bool {
 	return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue ||
 	return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue ||
 		err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice ||
 		err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice ||
-		// session err
-		err == session.ErrSessionEnded || err == session.ErrNoTransactStarted || err == session.ErrTransactInProgress ||
-		err == session.ErrAbortAfterCommit || err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort ||
+		// session errors
+		err == session.ErrSessionEnded || err == session.ErrNoTransactStarted ||
+		err == session.ErrTransactInProgress || err == session.ErrAbortAfterCommit ||
+		err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort ||
 		err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction
 		err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction
 }
 }
 
 

+ 13 - 6
core/stores/mon/model.go

@@ -21,7 +21,7 @@ type (
 		opts []Option
 		opts []Option
 	}
 	}
 
 
-	wrapSession struct {
+	wrappedSession struct {
 		mongo.Session
 		mongo.Session
 		brk breaker.Breaker
 		brk breaker.Breaker
 	}
 	}
@@ -74,7 +74,10 @@ func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session,
 			return sessionErr
 			return sessionErr
 		}
 		}
 
 
-		sess = &wrapSession{Session: session, brk: m.brk}
+		sess = &wrappedSession{
+			Session: session,
+			brk:     m.brk,
+		}
 
 
 		return nil
 		return nil
 	}, acceptable)
 	}, acceptable)
@@ -166,7 +169,7 @@ func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter interface{}, upd
 	return res.Decode(v)
 	return res.Decode(v)
 }
 }
 
 
-func (w *wrapSession) AbortTransaction(ctx context.Context) error {
+func (w *wrappedSession) AbortTransaction(ctx context.Context) error {
 	ctx, span := startSpan(ctx)
 	ctx, span := startSpan(ctx)
 	defer span.End()
 	defer span.End()
 
 
@@ -175,7 +178,7 @@ func (w *wrapSession) AbortTransaction(ctx context.Context) error {
 	}, acceptable)
 	}, acceptable)
 }
 }
 
 
-func (w *wrapSession) CommitTransaction(ctx context.Context) error {
+func (w *wrappedSession) CommitTransaction(ctx context.Context) error {
 	ctx, span := startSpan(ctx)
 	ctx, span := startSpan(ctx)
 	defer span.End()
 	defer span.End()
 
 
@@ -184,7 +187,11 @@ func (w *wrapSession) CommitTransaction(ctx context.Context) error {
 	}, acceptable)
 	}, acceptable)
 }
 }
 
 
-func (w *wrapSession) WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*mopt.TransactionOptions) (res interface{}, err error) {
+func (w *wrappedSession) WithTransaction(
+	ctx context.Context,
+	fn func(sessCtx mongo.SessionContext) (interface{}, error),
+	opts ...*mopt.TransactionOptions,
+) (res interface{}, err error) {
 	ctx, span := startSpan(ctx)
 	ctx, span := startSpan(ctx)
 	defer span.End()
 	defer span.End()
 
 
@@ -196,7 +203,7 @@ func (w *wrapSession) WithTransaction(ctx context.Context, fn func(sessCtx mongo
 	return
 	return
 }
 }
 
 
-func (w *wrapSession) EndSession(ctx context.Context) {
+func (w *wrappedSession) EndSession(ctx context.Context) {
 	ctx, span := startSpan(ctx)
 	ctx, span := startSpan(ctx)
 	defer span.End()
 	defer span.End()
 
 

+ 1 - 2
core/stores/mon/model_test.go

@@ -18,6 +18,7 @@ func TestModel_StartSession(t *testing.T) {
 		m := createModel(mt)
 		m := createModel(mt)
 		sess, err := m.StartSession()
 		sess, err := m.StartSession()
 		assert.Nil(t, err)
 		assert.Nil(t, err)
+		defer sess.EndSession(context.Background())
 
 
 		_, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) {
 		_, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) {
 			_ = sessCtx.StartTransaction()
 			_ = sessCtx.StartTransaction()
@@ -26,10 +27,8 @@ func TestModel_StartSession(t *testing.T) {
 			return nil, nil
 			return nil, nil
 		})
 		})
 		assert.Nil(t, err)
 		assert.Nil(t, err)
-
 		assert.NoError(t, sess.CommitTransaction(context.Background()))
 		assert.NoError(t, sess.CommitTransaction(context.Background()))
 		assert.Error(t, sess.AbortTransaction(context.Background()))
 		assert.Error(t, sess.AbortTransaction(context.Background()))
-		sess.EndSession(context.Background())
 	})
 	})
 }
 }
 
 

+ 10 - 7
core/stores/redis/hook.go

@@ -14,6 +14,9 @@ import (
 	tracestd "go.opentelemetry.io/otel/trace"
 	tracestd "go.opentelemetry.io/otel/trace"
 )
 )
 
 
+// spanName is the span name of the redis calls.
+const spanName = "redis"
+
 var (
 var (
 	startTimeKey = contextKey("startTime")
 	startTimeKey = contextKey("startTime")
 	spanKey      = contextKey("span")
 	spanKey      = contextKey("span")
@@ -28,11 +31,11 @@ type (
 )
 )
 
 
 func (h hook) BeforeProcess(ctx context.Context, _ red.Cmder) (context.Context, error) {
 func (h hook) BeforeProcess(ctx context.Context, _ red.Cmder) (context.Context, error) {
-	return h.spanStart(context.WithValue(ctx, startTimeKey, timex.Now())), nil
+	return h.startSpan(context.WithValue(ctx, startTimeKey, timex.Now())), nil
 }
 }
 
 
 func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
 func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
-	h.spanEnd(ctx)
+	h.endSpan(ctx)
 
 
 	val := ctx.Value(startTimeKey)
 	val := ctx.Value(startTimeKey)
 	if val == nil {
 	if val == nil {
@@ -53,11 +56,11 @@ func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
 }
 }
 
 
 func (h hook) BeforeProcessPipeline(ctx context.Context, _ []red.Cmder) (context.Context, error) {
 func (h hook) BeforeProcessPipeline(ctx context.Context, _ []red.Cmder) (context.Context, error) {
-	return h.spanStart(context.WithValue(ctx, startTimeKey, timex.Now())), nil
+	return h.startSpan(context.WithValue(ctx, startTimeKey, timex.Now())), nil
 }
 }
 
 
 func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error {
 func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error {
-	h.spanEnd(ctx)
+	h.endSpan(ctx)
 
 
 	if len(cmds) == 0 {
 	if len(cmds) == 0 {
 		return nil
 		return nil
@@ -92,12 +95,12 @@ func logDuration(ctx context.Context, cmd red.Cmder, duration time.Duration) {
 	logx.WithContext(ctx).WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String())
 	logx.WithContext(ctx).WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String())
 }
 }
 
 
-func (h hook) spanStart(ctx context.Context) context.Context {
-	ctx, span := h.tracer.Start(ctx, "redis")
+func (h hook) startSpan(ctx context.Context) context.Context {
+	ctx, span := h.tracer.Start(ctx, spanName)
 	return context.WithValue(ctx, spanKey, span)
 	return context.WithValue(ctx, spanKey, span)
 }
 }
 
 
-func (h hook) spanEnd(ctx context.Context) {
+func (h hook) endSpan(ctx context.Context) {
 	spanVal := ctx.Value(spanKey)
 	spanVal := ctx.Value(spanKey)
 	if spanVal == nil {
 	if spanVal == nil {
 		return
 		return

+ 4 - 2
core/stores/sqlx/sqlconn.go

@@ -11,6 +11,9 @@ import (
 	tracesdk "go.opentelemetry.io/otel/trace"
 	tracesdk "go.opentelemetry.io/otel/trace"
 )
 )
 
 
+// spanName is used to identify the span name for the SQL execution.
+const spanName = "sql"
+
 // ErrNotFound is an alias of sql.ErrNoRows
 // ErrNotFound is an alias of sql.ErrNoRows
 var ErrNotFound = sql.ErrNoRows
 var ErrNotFound = sql.ErrNoRows
 
 
@@ -240,7 +243,6 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
 	return db.queryRows(ctx, func(rows *sql.Rows) error {
 	return db.queryRows(ctx, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, false)
 		return unmarshalRows(v, rows, false)
 	}, q, args...)
 	}, q, args...)
-
 }
 }
 
 
 func (db *commonSqlConn) RawDB() (*sql.DB, error) {
 func (db *commonSqlConn) RawDB() (*sql.DB, error) {
@@ -362,5 +364,5 @@ func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args
 
 
 func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
 func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
 	tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
 	tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
-	return tracer.Start(ctx, "sql")
+	return tracer.Start(ctx, spanName)
 }
 }