Kevin Wan 2 tahun lalu
induk
melakukan
8ed22eafdd
2 mengubah file dengan 126 tambahan dan 45 penghapusan
  1. 85 45
      core/stores/sqlx/stmt.go
  2. 41 0
      core/stores/sqlx/stmt_test.go

+ 85 - 45
core/stores/sqlx/stmt.go

@@ -12,7 +12,22 @@ import (
 
 const defaultSlowThreshold = time.Millisecond * 500
 
-var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
+var (
+	slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
+	logSql        = syncx.ForAtomicBool(true)
+	logSlowSql    = syncx.ForAtomicBool(true)
+)
+
+// DisableLog disables logging of sql statements, includes info and slow logs.
+func DisableLog() {
+	logSql.Set(false)
+	logSlowSql.Set(false)
+}
+
+// DisableStmtLog disables info logging of sql statements, but keeps slow logs.
+func DisableStmtLog() {
+	logSql.Set(false)
+}
 
 // SetSlowThreshold sets the slow threshold.
 func SetSlowThreshold(threshold time.Duration) {
@@ -20,64 +35,39 @@ func SetSlowThreshold(threshold time.Duration) {
 }
 
 func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
-	stmt, err := format(q, args...)
-	if err != nil {
+	guard := newGuard("exec")
+	if err := guard.start(q, args...); err != nil {
 		return nil, err
 	}
 
-	startTime := timex.Now()
 	result, err := conn.ExecContext(ctx, q, args...)
-	duration := timex.Since(startTime)
-	if duration > slowThreshold.Load() {
-		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
-	} else {
-		logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt)
-	}
-	if err != nil {
-		logSqlError(ctx, stmt, err)
-	}
+	guard.finish(ctx, err)
 
 	return result, err
 }
 
 func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
-	stmt, err := format(q, args...)
-	if err != nil {
+	guard := newGuard("execStmt")
+	if err := guard.start(q, args...); err != nil {
 		return nil, err
 	}
 
-	startTime := timex.Now()
 	result, err := conn.ExecContext(ctx, args...)
-	duration := timex.Since(startTime)
-	if duration > slowThreshold.Load() {
-		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
-	} else {
-		logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt)
-	}
-	if err != nil {
-		logSqlError(ctx, stmt, err)
-	}
+	guard.finish(ctx, err)
 
 	return result, err
 }
 
 func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
 	q string, args ...interface{}) error {
-	stmt, err := format(q, args...)
-	if err != nil {
+	guard := newGuard("query")
+	if err := guard.start(q, args...); err != nil {
 		return err
 	}
 
-	startTime := timex.Now()
 	rows, err := conn.QueryContext(ctx, q, args...)
-	duration := timex.Since(startTime)
-	if duration > slowThreshold.Load() {
-		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
-	} else {
-		logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt)
-	}
+	guard.finish(ctx, err)
 	if err != nil {
-		logSqlError(ctx, stmt, err)
 		return err
 	}
 	defer rows.Close()
@@ -87,24 +77,74 @@ func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
 
 func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
 	q string, args ...interface{}) error {
-	stmt, err := format(q, args...)
-	if err != nil {
+	guard := newGuard("queryStmt")
+	if err := guard.start(q, args...); err != nil {
 		return err
 	}
 
-	startTime := timex.Now()
 	rows, err := conn.QueryContext(ctx, args...)
-	duration := timex.Since(startTime)
-	if duration > slowThreshold.Load() {
-		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
-	} else {
-		logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt)
-	}
+	guard.finish(ctx, err)
 	if err != nil {
-		logSqlError(ctx, stmt, err)
 		return err
 	}
 	defer rows.Close()
 
 	return scanner(rows)
 }
+
+type (
+	sqlGuard interface {
+		start(q string, args ...interface{}) error
+		finish(ctx context.Context, err error)
+	}
+
+	nilGuard struct{}
+
+	realSqlGuard struct {
+		command   string
+		stmt      string
+		startTime time.Duration
+	}
+)
+
+func newGuard(command string) sqlGuard {
+	if logSql.True() || logSlowSql.True() {
+		return &realSqlGuard{
+			command: command,
+		}
+	}
+
+	return nilGuard{}
+}
+
+func (n nilGuard) start(_ string, _ ...interface{}) error {
+	return nil
+}
+
+func (n nilGuard) finish(_ context.Context, _ error) {
+}
+
+func (e *realSqlGuard) finish(ctx context.Context, err error) {
+	duration := timex.Since(e.startTime)
+	if duration > slowThreshold.Load() {
+		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt)
+	} else if logSql.True() {
+		logx.WithContext(ctx).WithDuration(duration).Infof("sql %s: %s", e.command, e.stmt)
+	}
+
+	if err != nil {
+		logSqlError(ctx, e.stmt, err)
+	}
+}
+
+func (e *realSqlGuard) start(q string, args ...interface{}) error {
+	stmt, err := format(q, args...)
+	if err != nil {
+		return err
+	}
+
+	e.stmt = stmt
+	e.startTime = timex.Now()
+
+	return nil
+}

+ 41 - 0
core/stores/sqlx/stmt_test.go

@@ -178,6 +178,47 @@ func TestSetSlowThreshold(t *testing.T) {
 	assert.Equal(t, time.Second, slowThreshold.Load())
 }
 
+func TestDisableLog(t *testing.T) {
+	assert.True(t, logSql.True())
+	assert.True(t, logSlowSql.True())
+	defer func() {
+		logSql.Set(true)
+		logSlowSql.Set(true)
+	}()
+
+	DisableLog()
+	assert.False(t, logSql.True())
+	assert.False(t, logSlowSql.True())
+}
+
+func TestDisableStmtLog(t *testing.T) {
+	assert.True(t, logSql.True())
+	assert.True(t, logSlowSql.True())
+	defer func() {
+		logSql.Set(true)
+		logSlowSql.Set(true)
+	}()
+
+	DisableStmtLog()
+	assert.False(t, logSql.True())
+	assert.True(t, logSlowSql.True())
+}
+
+func TestNilGuard(t *testing.T) {
+	assert.True(t, logSql.True())
+	assert.True(t, logSlowSql.True())
+	defer func() {
+		logSql.Set(true)
+		logSlowSql.Set(true)
+	}()
+
+	DisableLog()
+	guard := newGuard("any")
+	assert.Nil(t, guard.start("foo", "bar"))
+	guard.finish(context.Background(), nil)
+	assert.Equal(t, nilGuard{}, guard)
+}
+
 type mockedSessionConn struct {
 	lastInsertId int64
 	rowsAffected int64