Quellcode durchsuchen

feat: support using session to execute statements in transaction (#3252)

Kevin Wan vor 1 Jahr
Ursprung
Commit
bff5b81ad9

+ 2 - 2
.gitignore

@@ -14,9 +14,10 @@
 **/.idea
 **/.DS_Store
 **/logs
+**/adhoc
+**/coverage.txt
 
 # for test purpose
-**/adhoc
 go.work
 go.work.sum
 
@@ -27,4 +28,3 @@ go.work.sum
 # vim auto backup file
 *~
 !OWNERS
-coverage.txt

+ 12 - 0
core/stores/sqlc/cachedsql.go

@@ -226,3 +226,15 @@ func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
 func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
 	return cc.db.TransactCtx(ctx, fn)
 }
+
+// WithSession returns a new CachedConn with given session.
+// If query from session, the uncommitted data might be returned.
+// Don't query for the uncommitted data, you should just use it,
+// and don't use the cache for the uncommitted data.
+// Not recommend to use cache within transactions due to consistency problem.
+func (cc CachedConn) WithSession(session sqlx.Session) CachedConn {
+	return CachedConn{
+		db:    sqlx.NewSqlConnFromSession(session),
+		cache: cc.cache,
+	}
+}

+ 152 - 12
core/stores/sqlc/cachedsql_test.go

@@ -15,6 +15,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/DATA-DOG/go-sqlmock"
 	"github.com/alicebob/miniredis/v2"
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/fx"
@@ -24,6 +25,8 @@ import (
 	"github.com/zeromicro/go-zero/core/stores/redis"
 	"github.com/zeromicro/go-zero/core/stores/redis/redistest"
 	"github.com/zeromicro/go-zero/core/stores/sqlx"
+	"github.com/zeromicro/go-zero/core/syncx"
+	"github.com/zeromicro/go-zero/internal/dbtest"
 )
 
 func init() {
@@ -39,7 +42,7 @@ func TestCachedConn_GetCache(t *testing.T) {
 	var value string
 	err := c.GetCache("any", &value)
 	assert.Equal(t, ErrNotFound, err)
-	r.Set("any", `"value"`)
+	_ = r.Set("any", `"value"`)
 	err = c.GetCache("any", &value)
 	assert.Nil(t, err)
 	assert.Equal(t, "value", value)
@@ -368,6 +371,24 @@ func TestStatFromMemory(t *testing.T) {
 	assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
 }
 
+func TestCachedConn_DelCache(t *testing.T) {
+	r := redistest.CreateRedis(t)
+
+	const (
+		key   = "user"
+		value = "any"
+	)
+	assert.NoError(t, r.Set(key, value))
+
+	c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30))
+	err := c.DelCache(key)
+	assert.Nil(t, err)
+
+	val, err := r.Get(key)
+	assert.Nil(t, err)
+	assert.Empty(t, val)
+}
+
 func TestCachedConnQueryRow(t *testing.T) {
 	r := redistest.CreateRedis(t)
 
@@ -543,6 +564,125 @@ func TestNewConnWithCache(t *testing.T) {
 	assert.True(t, conn.execValue)
 }
 
+func TestCachedConn_WithSession(t *testing.T) {
+	dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
+		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
+
+		r := redistest.CreateRedis(t)
+		conn := CachedConn{
+			cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
+		}
+		conn = conn.WithSession(sqlx.NewSessionFromTx(tx))
+		res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
+			return conn.Exec("any")
+		}, "foo")
+		assert.NoError(t, err)
+		last, err := res.LastInsertId()
+		assert.NoError(t, err)
+		assert.Equal(t, int64(2), last)
+		affected, err := res.RowsAffected()
+		assert.NoError(t, err)
+		assert.Equal(t, int64(3), affected)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
+		mock.ExpectCommit()
+
+		r := redistest.CreateRedis(t)
+		conn := CachedConn{
+			db:    sqlx.NewSqlConnFromDB(db),
+			cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
+		}
+		assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
+			conn = conn.WithSession(session)
+			res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
+				return conn.Exec("any")
+			}, "foo")
+			assert.NoError(t, err)
+			last, err := res.LastInsertId()
+			assert.NoError(t, err)
+			assert.Equal(t, int64(2), last)
+			affected, err := res.RowsAffected()
+			assert.NoError(t, err)
+			assert.Equal(t, int64(3), affected)
+			return nil
+		}))
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectExec("any").WillReturnError(errors.New("foo"))
+		mock.ExpectRollback()
+
+		r := redistest.CreateRedis(t)
+		conn := CachedConn{
+			db:    sqlx.NewSqlConnFromDB(db),
+			cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
+		}
+		assert.Error(t, conn.Transact(func(session sqlx.Session) error {
+			conn = conn.WithSession(session)
+			_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
+				return conn.Exec("any")
+			}, "bar")
+			return err
+		}))
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
+		mock.ExpectCommit()
+
+		r := redistest.CreateRedis(t)
+		conn := CachedConn{
+			db:    sqlx.NewSqlConnFromDB(db),
+			cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
+		}
+		assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
+			var val string
+			conn = conn.WithSession(session)
+			err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
+				return conn.QueryRow(v, "any")
+			})
+			assert.Equal(t, "2", val)
+			return err
+		}))
+		val, err := r.Get("foo")
+		assert.NoError(t, err)
+		assert.Equal(t, `"2"`, val)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
+		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
+		mock.ExpectCommit()
+
+		r := redistest.CreateRedis(t)
+		conn := CachedConn{
+			db:    sqlx.NewSqlConnFromDB(db),
+			cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
+		}
+		assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
+			var val string
+			conn = conn.WithSession(session)
+			assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
+				return conn.QueryRow(v, "any")
+			}))
+			assert.Equal(t, "2", val)
+			_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
+				return conn.Exec("any")
+			}, "foo")
+			return err
+		}))
+		val, err := r.Get("foo")
+		assert.NoError(t, err)
+		assert.Empty(t, val)
+	})
+}
+
 func resetStats() {
 	atomic.StoreUint64(&stats.Total, 0)
 	atomic.StoreUint64(&stats.Hit, 0)
@@ -554,35 +694,35 @@ type dummySqlConn struct {
 	queryRow func(any, string, ...any) error
 }
 
-func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
+func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
 	return nil, nil
 }
 
-func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
+func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) {
 	return nil, nil
 }
 
-func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
+func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
+func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
+func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
+func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error {
 	return nil
 }
 
-func (d dummySqlConn) Exec(query string, args ...any) (sql.Result, error) {
+func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) {
 	return nil, nil
 }
 
-func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
+func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) {
 	return nil, nil
 }
 
@@ -597,15 +737,15 @@ func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args .
 	return nil
 }
 
-func (d dummySqlConn) QueryRowPartial(v any, query string, args ...any) error {
+func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (d dummySqlConn) QueryRows(v any, query string, args ...any) error {
+func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (d dummySqlConn) QueryRowsPartial(v any, query string, args ...any) error {
+func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error {
 	return nil
 }
 

+ 4 - 20
core/stores/sqlx/bulkinserter_test.go

@@ -9,7 +9,7 @@ import (
 
 	"github.com/DATA-DOG/go-sqlmock"
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/internal/dbtest"
 )
 
 type mockedConn struct {
@@ -81,7 +81,7 @@ func (c *mockedConn) Transact(func(session Session) error) error {
 }
 
 func TestBulkInserter(t *testing.T) {
-	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		var conn mockedConn
 		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
 		assert.Nil(t, err)
@@ -98,7 +98,7 @@ func TestBulkInserter(t *testing.T) {
 }
 
 func TestBulkInserterSuffix(t *testing.T) {
-	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		var conn mockedConn
 		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
 			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
@@ -119,7 +119,7 @@ func TestBulkInserterSuffix(t *testing.T) {
 }
 
 func TestBulkInserterBadStatement(t *testing.T) {
-	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		var conn mockedConn
 		_, err := NewBulkInserter(&conn, "foo")
 		assert.NotNil(t, err)
@@ -144,19 +144,3 @@ func TestBulkInserter_Update(t *testing.T) {
 	assert.NotNil(t, inserter.UpdateStmt("foo"))
 	assert.NotNil(t, inserter.Insert("foo", "bar"))
 }
-
-func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
-	logx.Disable()
-
-	db, mock, err := sqlmock.New()
-	if err != nil {
-		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
-	}
-	defer db.Close()
-
-	fn(db, mock)
-
-	if err := mock.ExpectationsWereMet(); err != nil {
-		t.Errorf("there were unfulfilled expectations: %s", err)
-	}
-}

+ 14 - 0
core/stores/sqlx/errors.go

@@ -0,0 +1,14 @@
+package sqlx
+
+import (
+	"database/sql"
+	"errors"
+)
+
+var (
+	// ErrNotFound is an alias of sql.ErrNoRows
+	ErrNotFound = sql.ErrNoRows
+
+	errCantNestTx    = errors.New("cannot nest transactions")
+	errNoRawDBFromTx = errors.New("cannot get raw db from transaction")
+)

+ 60 - 76
core/stores/sqlx/orm_test.go

@@ -8,11 +8,11 @@ import (
 
 	"github.com/DATA-DOG/go-sqlmock"
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/internal/dbtest"
 )
 
 func TestUnmarshalRowBool(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -25,7 +25,7 @@ func TestUnmarshalRowBool(t *testing.T) {
 }
 
 func TestUnmarshalRowBoolNotSettable(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -37,7 +37,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
 }
 
 func TestUnmarshalRowInt(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -50,7 +50,7 @@ func TestUnmarshalRowInt(t *testing.T) {
 }
 
 func TestUnmarshalRowInt8(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -63,7 +63,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
 }
 
 func TestUnmarshalRowInt16(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -76,7 +76,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
 }
 
 func TestUnmarshalRowInt32(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -89,7 +89,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
 }
 
 func TestUnmarshalRowInt64(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -102,7 +102,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
 }
 
 func TestUnmarshalRowUint(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -115,7 +115,7 @@ func TestUnmarshalRowUint(t *testing.T) {
 }
 
 func TestUnmarshalRowUint8(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -128,7 +128,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
 }
 
 func TestUnmarshalRowUint16(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -141,7 +141,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
 }
 
 func TestUnmarshalRowUint32(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -154,7 +154,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
 }
 
 func TestUnmarshalRowUint64(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -167,7 +167,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
 }
 
 func TestUnmarshalRowFloat32(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -180,7 +180,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
 }
 
 func TestUnmarshalRowFloat64(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -193,7 +193,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
 }
 
 func TestUnmarshalRowString(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		const expect = "hello"
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -212,7 +212,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
 		Age  int
 	})
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -230,7 +230,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
 		Name string `db:"name"`
 	})
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -248,7 +248,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
 		Name string `db:"name"`
 	})
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -259,7 +259,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
 }
 
 func TestUnmarshalRowsBool(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []bool{true, false}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -273,7 +273,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
 }
 
 func TestUnmarshalRowsInt(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []int{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -287,7 +287,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
 }
 
 func TestUnmarshalRowsInt8(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []int8{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -301,7 +301,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
 }
 
 func TestUnmarshalRowsInt16(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []int16{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -315,7 +315,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
 }
 
 func TestUnmarshalRowsInt32(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []int32{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -329,7 +329,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
 }
 
 func TestUnmarshalRowsInt64(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []int64{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -343,7 +343,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
 }
 
 func TestUnmarshalRowsUint(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []uint{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -357,7 +357,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
 }
 
 func TestUnmarshalRowsUint8(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []uint8{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -371,7 +371,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
 }
 
 func TestUnmarshalRowsUint16(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []uint16{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -385,7 +385,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
 }
 
 func TestUnmarshalRowsUint32(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []uint32{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -399,7 +399,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
 }
 
 func TestUnmarshalRowsUint64(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []uint64{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -413,7 +413,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
 }
 
 func TestUnmarshalRowsFloat32(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []float32{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -427,7 +427,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
 }
 
 func TestUnmarshalRowsFloat64(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []float64{2, 3}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -441,7 +441,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
 }
 
 func TestUnmarshalRowsString(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []string{"hello", "world"}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -457,7 +457,7 @@ func TestUnmarshalRowsString(t *testing.T) {
 func TestUnmarshalRowsBoolPtr(t *testing.T) {
 	yes := true
 	no := false
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*bool{&yes, &no}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -473,7 +473,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
 func TestUnmarshalRowsIntPtr(t *testing.T) {
 	two := 2
 	three := 3
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*int{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -489,7 +489,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
 func TestUnmarshalRowsInt8Ptr(t *testing.T) {
 	two := int8(2)
 	three := int8(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*int8{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -505,7 +505,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
 func TestUnmarshalRowsInt16Ptr(t *testing.T) {
 	two := int16(2)
 	three := int16(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*int16{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -521,7 +521,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
 func TestUnmarshalRowsInt32Ptr(t *testing.T) {
 	two := int32(2)
 	three := int32(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*int32{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -537,7 +537,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
 func TestUnmarshalRowsInt64Ptr(t *testing.T) {
 	two := int64(2)
 	three := int64(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*int64{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -553,7 +553,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
 func TestUnmarshalRowsUintPtr(t *testing.T) {
 	two := uint(2)
 	three := uint(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*uint{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -569,7 +569,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
 func TestUnmarshalRowsUint8Ptr(t *testing.T) {
 	two := uint8(2)
 	three := uint8(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*uint8{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -585,7 +585,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
 func TestUnmarshalRowsUint16Ptr(t *testing.T) {
 	two := uint16(2)
 	three := uint16(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*uint16{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -601,7 +601,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
 func TestUnmarshalRowsUint32Ptr(t *testing.T) {
 	two := uint32(2)
 	three := uint32(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*uint32{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -617,7 +617,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
 func TestUnmarshalRowsUint64Ptr(t *testing.T) {
 	two := uint64(2)
 	three := uint64(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*uint64{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -633,7 +633,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
 func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
 	two := float32(2)
 	three := float32(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*float32{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -649,7 +649,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
 func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
 	two := float64(2)
 	three := float64(3)
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*float64{&two, &three}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -665,7 +665,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
 func TestUnmarshalRowsStringPtr(t *testing.T) {
 	hello := "hello"
 	world := "world"
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		expect := []*string{&hello, &world}
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -697,7 +697,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
 		Age  int64
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -736,7 +736,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
 		NullString sql.NullString `db:"value"`
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
 			"first", "firstnullstring").AddRow("second", nil)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -771,7 +771,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
 		Name string `db:"name"`
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -812,7 +812,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
 		Embed
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -854,7 +854,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
 		*Embed
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -888,7 +888,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
 		Age  int64
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -921,7 +921,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
 		Name string `db:"name"`
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -954,7 +954,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
 		Name string `db:"name"`
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -969,7 +969,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
 }
 
 func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -1019,7 +1019,7 @@ func TestUnmarshalRowError(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+			dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 				rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
 				mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
 					"anyone").WillReturnRows(rs)
@@ -1091,7 +1091,7 @@ func TestAnonymousStructPr(t *testing.T) {
 		Name string `db:"name"`
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{
 			"name",
 			"age",
@@ -1139,7 +1139,7 @@ func TestAnonymousStructPrError(t *testing.T) {
 		Name string `db:"name"`
 	}
 
-	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{
 			"name",
 			"age",
@@ -1154,7 +1154,7 @@ func TestAnonymousStructPrError(t *testing.T) {
 			WithArgs("anyone").WillReturnRows(rs)
 		assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
-		}, "select name, age,grade,discipline,class_name,score from users where user=?",
+		}, "select name, age, grade, discipline, class_name, score from users where user=?",
 			"anyone"))
 		if len(value) > 0 {
 			assert.Equal(t, value[0].score, 0)
@@ -1162,22 +1162,6 @@ func TestAnonymousStructPrError(t *testing.T) {
 	})
 }
 
-func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
-	logx.Disable()
-
-	db, mock, err := sqlmock.New()
-	if err != nil {
-		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
-	}
-	defer db.Close()
-
-	fn(db, mock)
-
-	if err := mock.ExpectationsWereMet(); err != nil {
-		t.Errorf("there were unfulfilled expectations: %s", err)
-	}
-}
-
 type mockedScanner struct {
 	colErr  error
 	scanErr error

+ 7 - 3
core/stores/sqlx/sqlconn.go

@@ -11,9 +11,6 @@ import (
 // spanName is used to identify the span name for the SQL execution.
 const spanName = "sql"
 
-// ErrNotFound is an alias of sql.ErrNoRows
-var ErrNotFound = sql.ErrNoRows
-
 type (
 	// Session stands for raw connections or transaction sessions
 	Session interface {
@@ -131,6 +128,13 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
 	return conn
 }
 
+// NewSqlConnFromSession returns a SqlConn with the given session.
+func NewSqlConnFromSession(session Session) SqlConn {
+	return txConn{
+		Session: session,
+	}
+}
+
 func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
 	return db.ExecCtx(context.Background(), q, args...)
 }

+ 1 - 1
core/stores/sqlx/sqlconn_test.go

@@ -55,7 +55,7 @@ func TestSqlConn(t *testing.T) {
 }
 
 func buildConn() (mock sqlmock.Sqlmock, err error) {
-	connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
+	_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
 		var db *sql.DB
 		var err error
 		db, mock, err = sqlmock.New()

+ 16 - 0
core/stores/sqlx/tx.go

@@ -15,11 +15,27 @@ type (
 		Rollback() error
 	}
 
+	txConn struct {
+		Session
+	}
+
 	txSession struct {
 		*sql.Tx
 	}
 )
 
+func (s txConn) RawDB() (*sql.DB, error) {
+	return nil, errNoRawDBFromTx
+}
+
+func (s txConn) Transact(_ func(Session) error) error {
+	return errCantNestTx
+}
+
+func (s txConn) TransactCtx(_ context.Context, _ func(context.Context, Session) error) error {
+	return errCantNestTx
+}
+
 // NewSessionFromTx returns a Session with the given sql.Tx.
 // Use it with caution, it's provided for other ORM to interact with.
 func NewSessionFromTx(tx *sql.Tx) Session {

+ 221 - 12
core/stores/sqlx/tx_test.go

@@ -6,7 +6,10 @@ import (
 	"errors"
 	"testing"
 
+	"github.com/DATA-DOG/go-sqlmock"
 	"github.com/stretchr/testify/assert"
+	"github.com/zeromicro/go-zero/core/breaker"
+	"github.com/zeromicro/go-zero/internal/dbtest"
 )
 
 const (
@@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error {
 	return nil
 }
 
-func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) {
+func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
 	return nil, nil
 }
 
-func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
+func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
 	return nil, nil
 }
 
-func (mt *mockTx) Prepare(query string) (StmtSession, error) {
+func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
 	return nil, nil
 }
 
-func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
+func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
 	return nil, nil
 }
 
-func (mt *mockTx) QueryRow(v any, q string, args ...any) error {
+func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
+func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error {
+func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
+func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (mt *mockTx) QueryRows(v any, q string, args ...any) error {
+func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
+func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error {
+func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
 	return nil
 }
 
-func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
+func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
 	return nil
 }
 
@@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) {
 	assert.Equal(t, mockRollback, mock.status)
 	assert.NotNil(t, err)
 }
+
+func TestTxExceptions(t *testing.T) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectCommit()
+		conn := NewSqlConnFromDB(db)
+		assert.NoError(t, conn.Transact(func(session Session) error {
+			return nil
+		}))
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		conn := &commonSqlConn{
+			connProv: func() (*sql.DB, error) {
+				return nil, errors.New("foo")
+			},
+			beginTx: begin,
+			onError: func(ctx context.Context, err error) {},
+			brk:     breaker.NewBreaker(),
+		}
+		assert.Error(t, conn.Transact(func(session Session) error {
+			return nil
+		}))
+	})
+
+	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
+		_, err := conn.RawDB()
+		assert.Equal(t, errNoRawDBFromTx, err)
+		assert.Equal(t, errCantNestTx, conn.Transact(nil))
+		assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		conn := NewSqlConnFromDB(db)
+		assert.Error(t, conn.Transact(func(session Session) error {
+			return errors.New("foo")
+		}))
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectRollback().WillReturnError(errors.New("foo"))
+		conn := NewSqlConnFromDB(db)
+		assert.Error(t, conn.Transact(func(session Session) error {
+			panic("foo")
+		}))
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectRollback()
+		conn := NewSqlConnFromDB(db)
+		assert.Error(t, conn.Transact(func(session Session) error {
+			panic(errors.New("foo"))
+		}))
+	})
+}
+
+func TestTxSession(t *testing.T) {
+	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
+		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
+		res, err := conn.Exec("any")
+		assert.NoError(t, err)
+		last, err := res.LastInsertId()
+		assert.NoError(t, err)
+		assert.Equal(t, int64(2), last)
+		affected, err := res.RowsAffected()
+		assert.NoError(t, err)
+		assert.Equal(t, int64(3), affected)
+
+		mock.ExpectExec("any").WillReturnError(errors.New("foo"))
+		_, err = conn.Exec("any")
+		assert.Equal(t, "foo", err.Error())
+	})
+
+	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
+		mock.ExpectPrepare("any")
+		stmt, err := conn.Prepare("any")
+		assert.NoError(t, err)
+		assert.NotNil(t, stmt)
+
+		mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
+		_, err = conn.Prepare("any")
+		assert.Equal(t, "foo", err.Error())
+	})
+
+	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
+		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
+		mock.ExpectQuery("any").WillReturnRows(rows)
+		var val string
+		err := conn.QueryRow(&val, "any")
+		assert.NoError(t, err)
+		assert.Equal(t, "foo", val)
+
+		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
+		err = conn.QueryRow(&val, "any")
+		assert.Equal(t, "foo", err.Error())
+	})
+
+	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
+		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
+		mock.ExpectQuery("any").WillReturnRows(rows)
+		var val string
+		err := conn.QueryRowPartial(&val, "any")
+		assert.NoError(t, err)
+		assert.Equal(t, "foo", val)
+
+		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
+		err = conn.QueryRowPartial(&val, "any")
+		assert.Equal(t, "foo", err.Error())
+	})
+
+	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
+		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
+		mock.ExpectQuery("any").WillReturnRows(rows)
+		var val []string
+		err := conn.QueryRows(&val, "any")
+		assert.NoError(t, err)
+		assert.Equal(t, []string{"foo", "bar"}, val)
+
+		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
+		err = conn.QueryRows(&val, "any")
+		assert.Equal(t, "foo", err.Error())
+	})
+
+	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
+		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
+		mock.ExpectQuery("any").WillReturnRows(rows)
+		var val []string
+		err := conn.QueryRowsPartial(&val, "any")
+		assert.NoError(t, err)
+		assert.Equal(t, []string{"foo", "bar"}, val)
+
+		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
+		err = conn.QueryRowsPartial(&val, "any")
+		assert.Equal(t, "foo", err.Error())
+	})
+}
+
+func TestTxRollback(t *testing.T) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
+		mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
+		mock.ExpectRollback()
+
+		conn := NewSqlConnFromDB(db)
+		err := conn.Transact(func(session Session) error {
+			c := NewSqlConnFromSession(session)
+			_, err := c.Exec("any")
+			assert.NoError(t, err)
+
+			var val string
+			return c.QueryRow(&val, "foo")
+		})
+		assert.Error(t, err)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectExec("any").WillReturnError(errors.New("foo"))
+		mock.ExpectRollback()
+
+		conn := NewSqlConnFromDB(db)
+		err := conn.Transact(func(session Session) error {
+			c := NewSqlConnFromSession(session)
+			if _, err := c.Exec("any"); err != nil {
+				return err
+			}
+
+			var val string
+			assert.NoError(t, c.QueryRow(&val, "foo"))
+			return nil
+		})
+		assert.Error(t, err)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
+		mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
+		mock.ExpectCommit()
+
+		conn := NewSqlConnFromDB(db)
+		err := conn.Transact(func(session Session) error {
+			c := NewSqlConnFromSession(session)
+			_, err := c.Exec("any")
+			assert.NoError(t, err)
+
+			var val string
+			assert.NoError(t, c.QueryRow(&val, "foo"))
+			assert.Equal(t, "bar", val)
+			return nil
+		})
+		assert.NoError(t, err)
+	})
+}
+
+func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
+	dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
+		sess := NewSessionFromTx(tx)
+		conn := NewSqlConnFromSession(sess)
+		f(conn, mock)
+	})
+}

+ 37 - 0
internal/dbtest/sql.go

@@ -0,0 +1,37 @@
+package dbtest
+
+import (
+	"database/sql"
+	"testing"
+
+	"github.com/DATA-DOG/go-sqlmock"
+	"github.com/stretchr/testify/assert"
+)
+
+// RunTest runs a test function with a mock database.
+func RunTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
+	db, mock, err := sqlmock.New()
+	if err != nil {
+		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+	}
+	defer func() {
+		_ = db.Close()
+	}()
+
+	fn(db, mock)
+
+	if err = mock.ExpectationsWereMet(); err != nil {
+		t.Errorf("there were unfulfilled expectations: %s", err)
+	}
+}
+
+// RunTxTest runs a test function with a mock database in a transaction.
+func RunTxTest(t *testing.T, f func(tx *sql.Tx, mock sqlmock.Sqlmock)) {
+	RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		mock.ExpectBegin()
+		tx, err := db.Begin()
+		if assert.NoError(t, err) {
+			f(tx, mock)
+		}
+	})
+}