123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- package sqlx
- import (
- "context"
- "database/sql"
- "errors"
- "testing"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
- "github.com/wuntsong-org/go-zero-plus/core/breaker"
- "github.com/wuntsong-org/go-zero-plus/core/stores/dbtest"
- )
// Bit flags stored in mockTx.status. Commit and Rollback OR their flag in,
// so a test can tell exactly which of the two (or both) were invoked.
const (
	mockCommit   = 1 << iota // set by (*mockTx).Commit
	mockRollback             // set by (*mockTx).Rollback
)
- type mockTx struct {
- status int
- }
- func (mt *mockTx) Commit() error {
- mt.status |= mockCommit
- return nil
- }
- func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
- return nil, nil
- }
- func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
- return nil, nil
- }
- func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
- return nil, nil
- }
- func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
- return nil, nil
- }
- func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
- return nil
- }
- func (mt *mockTx) Rollback() error {
- mt.status |= mockRollback
- return nil
- }
- func beginMock(mock *mockTx) beginnable {
- return func(*sql.DB) (trans, error) {
- return mock, nil
- }
- }
- func TestTransactCommit(t *testing.T) {
- mock := &mockTx{}
- err := transactOnConn(context.Background(), nil, beginMock(mock),
- func(context.Context, Session) error {
- return nil
- })
- assert.Equal(t, mockCommit, mock.status)
- assert.Nil(t, err)
- }
- func TestTransactRollback(t *testing.T) {
- mock := &mockTx{}
- err := transactOnConn(context.Background(), nil, beginMock(mock),
- func(context.Context, Session) error {
- return errors.New("rollback")
- })
- assert.Equal(t, mockRollback, mock.status)
- assert.NotNil(t, err)
- }
- func TestTxExceptions(t *testing.T) {
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- mock.ExpectBegin()
- mock.ExpectCommit()
- conn := NewSqlConnFromDB(db)
- assert.NoError(t, conn.Transact(func(session Session) error {
- return nil
- }))
- })
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- conn := &commonSqlConn{
- connProv: func() (*sql.DB, error) {
- return nil, errors.New("foo")
- },
- beginTx: begin,
- onError: func(ctx context.Context, err error) {},
- brk: breaker.NewBreaker(),
- }
- assert.Error(t, conn.Transact(func(session Session) error {
- return nil
- }))
- })
- runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
- _, err := conn.RawDB()
- assert.Equal(t, errNoRawDBFromTx, err)
- assert.Equal(t, errCantNestTx, conn.Transact(nil))
- assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
- })
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- mock.ExpectBegin()
- conn := NewSqlConnFromDB(db)
- assert.Error(t, conn.Transact(func(session Session) error {
- return errors.New("foo")
- }))
- })
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- mock.ExpectBegin()
- mock.ExpectRollback().WillReturnError(errors.New("foo"))
- conn := NewSqlConnFromDB(db)
- assert.Error(t, conn.Transact(func(session Session) error {
- panic("foo")
- }))
- })
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- mock.ExpectBegin()
- mock.ExpectRollback()
- conn := NewSqlConnFromDB(db)
- assert.Error(t, conn.Transact(func(session Session) error {
- panic(errors.New("foo"))
- }))
- })
- }
- func TestTxSession(t *testing.T) {
- runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
- mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
- res, err := conn.Exec("any")
- assert.NoError(t, err)
- last, err := res.LastInsertId()
- assert.NoError(t, err)
- assert.Equal(t, int64(2), last)
- affected, err := res.RowsAffected()
- assert.NoError(t, err)
- assert.Equal(t, int64(3), affected)
- mock.ExpectExec("any").WillReturnError(errors.New("foo"))
- _, err = conn.Exec("any")
- assert.Equal(t, "foo", err.Error())
- })
- runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
- mock.ExpectPrepare("any")
- stmt, err := conn.Prepare("any")
- assert.NoError(t, err)
- assert.NotNil(t, stmt)
- mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
- _, err = conn.Prepare("any")
- assert.Equal(t, "foo", err.Error())
- })
- runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
- rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
- mock.ExpectQuery("any").WillReturnRows(rows)
- var val string
- err := conn.QueryRow(&val, "any")
- assert.NoError(t, err)
- assert.Equal(t, "foo", val)
- mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
- err = conn.QueryRow(&val, "any")
- assert.Equal(t, "foo", err.Error())
- })
- runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
- rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
- mock.ExpectQuery("any").WillReturnRows(rows)
- var val string
- err := conn.QueryRowPartial(&val, "any")
- assert.NoError(t, err)
- assert.Equal(t, "foo", val)
- mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
- err = conn.QueryRowPartial(&val, "any")
- assert.Equal(t, "foo", err.Error())
- })
- runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
- rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
- mock.ExpectQuery("any").WillReturnRows(rows)
- var val []string
- err := conn.QueryRows(&val, "any")
- assert.NoError(t, err)
- assert.Equal(t, []string{"foo", "bar"}, val)
- mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
- err = conn.QueryRows(&val, "any")
- assert.Equal(t, "foo", err.Error())
- })
- runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
- rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
- mock.ExpectQuery("any").WillReturnRows(rows)
- var val []string
- err := conn.QueryRowsPartial(&val, "any")
- assert.NoError(t, err)
- assert.Equal(t, []string{"foo", "bar"}, val)
- mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
- err = conn.QueryRowsPartial(&val, "any")
- assert.Equal(t, "foo", err.Error())
- })
- }
- func TestTxRollback(t *testing.T) {
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- mock.ExpectBegin()
- mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
- mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
- mock.ExpectRollback()
- conn := NewSqlConnFromDB(db)
- err := conn.Transact(func(session Session) error {
- c := NewSqlConnFromSession(session)
- _, err := c.Exec("any")
- assert.NoError(t, err)
- var val string
- return c.QueryRow(&val, "foo")
- })
- assert.Error(t, err)
- })
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- mock.ExpectBegin()
- mock.ExpectExec("any").WillReturnError(errors.New("foo"))
- mock.ExpectRollback()
- conn := NewSqlConnFromDB(db)
- err := conn.Transact(func(session Session) error {
- c := NewSqlConnFromSession(session)
- if _, err := c.Exec("any"); err != nil {
- return err
- }
- var val string
- assert.NoError(t, c.QueryRow(&val, "foo"))
- return nil
- })
- assert.Error(t, err)
- })
- dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
- mock.ExpectBegin()
- mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
- mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
- mock.ExpectCommit()
- conn := NewSqlConnFromDB(db)
- err := conn.Transact(func(session Session) error {
- c := NewSqlConnFromSession(session)
- _, err := c.Exec("any")
- assert.NoError(t, err)
- var val string
- assert.NoError(t, c.QueryRow(&val, "foo"))
- assert.Equal(t, "bar", val)
- return nil
- })
- assert.NoError(t, err)
- })
- }
- func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
- dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
- sess := NewSessionFromTx(tx)
- conn := NewSqlConnFromSession(sess)
- f(conn, mock)
- })
- }
|