tx_test.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package sqlx
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. )
  9. const (
  10. mockCommit = 1
  11. mockRollback = 2
  12. )
  13. type mockTx struct {
  14. status int
  15. }
  16. func (mt *mockTx) Commit() error {
  17. mt.status |= mockCommit
  18. return nil
  19. }
  20. func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) {
  21. return nil, nil
  22. }
  23. func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
  24. return nil, nil
  25. }
  26. func (mt *mockTx) Prepare(query string) (StmtSession, error) {
  27. return nil, nil
  28. }
  29. func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
  30. return nil, nil
  31. }
  32. func (mt *mockTx) QueryRow(v any, q string, args ...any) error {
  33. return nil
  34. }
  35. func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
  36. return nil
  37. }
  38. func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error {
  39. return nil
  40. }
  41. func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
  42. return nil
  43. }
  44. func (mt *mockTx) QueryRows(v any, q string, args ...any) error {
  45. return nil
  46. }
  47. func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
  48. return nil
  49. }
  50. func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error {
  51. return nil
  52. }
  53. func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
  54. return nil
  55. }
  56. func (mt *mockTx) Rollback() error {
  57. mt.status |= mockRollback
  58. return nil
  59. }
  60. func beginMock(mock *mockTx) beginnable {
  61. return func(*sql.DB) (trans, error) {
  62. return mock, nil
  63. }
  64. }
  65. func TestTransactCommit(t *testing.T) {
  66. mock := &mockTx{}
  67. err := transactOnConn(context.Background(), nil, beginMock(mock),
  68. func(context.Context, Session) error {
  69. return nil
  70. })
  71. assert.Equal(t, mockCommit, mock.status)
  72. assert.Nil(t, err)
  73. }
  74. func TestTransactRollback(t *testing.T) {
  75. mock := &mockTx{}
  76. err := transactOnConn(context.Background(), nil, beginMock(mock),
  77. func(context.Context, Session) error {
  78. return errors.New("rollback")
  79. })
  80. assert.Equal(t, mockRollback, mock.status)
  81. assert.NotNil(t, err)
  82. }