tx.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. )
  6. type (
  7. beginnable func(*sql.DB) (trans, error)
  8. trans interface {
  9. Session
  10. Commit() error
  11. Rollback() error
  12. }
  13. txSession struct {
  14. *sql.Tx
  15. }
  16. )
  17. func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
  18. return exec(t.Tx, q, args...)
  19. }
  20. func (t txSession) Prepare(q string) (StmtSession, error) {
  21. stmt, err := t.Tx.Prepare(q)
  22. if err != nil {
  23. return nil, err
  24. }
  25. return statement{
  26. query: q,
  27. stmt: stmt,
  28. }, nil
  29. }
  30. func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
  31. return query(t.Tx, func(rows *sql.Rows) error {
  32. return unmarshalRow(v, rows, true)
  33. }, q, args...)
  34. }
  35. func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  36. return query(t.Tx, func(rows *sql.Rows) error {
  37. return unmarshalRow(v, rows, false)
  38. }, q, args...)
  39. }
  40. func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
  41. return query(t.Tx, func(rows *sql.Rows) error {
  42. return unmarshalRows(v, rows, true)
  43. }, q, args...)
  44. }
  45. func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  46. return query(t.Tx, func(rows *sql.Rows) error {
  47. return unmarshalRows(v, rows, false)
  48. }, q, args...)
  49. }
  50. func begin(db *sql.DB) (trans, error) {
  51. tx, err := db.Begin()
  52. if err != nil {
  53. return nil, err
  54. }
  55. return txSession{
  56. Tx: tx,
  57. }, nil
  58. }
  59. func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
  60. conn, err := db.connProv()
  61. if err != nil {
  62. db.onError(err)
  63. return err
  64. }
  65. return transactOnConn(conn, b, fn)
  66. }
  67. func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
  68. var tx trans
  69. tx, err = b(conn)
  70. if err != nil {
  71. return
  72. }
  73. defer func() {
  74. if p := recover(); p != nil {
  75. if e := tx.Rollback(); e != nil {
  76. err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
  77. } else {
  78. err = fmt.Errorf("recoveer from %#v", p)
  79. }
  80. } else if err != nil {
  81. if e := tx.Rollback(); e != nil {
  82. err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
  83. }
  84. } else {
  85. err = tx.Commit()
  86. }
  87. }()
  88. return fn(tx)
  89. }