tx.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. )
  6. type (
  7. beginnable func(*sql.DB) (trans, error)
  8. trans interface {
  9. Session
  10. Commit() error
  11. Rollback() error
  12. }
  13. txSession struct {
  14. *sql.Tx
  15. }
  16. )
  17. // NewSessionFromTx returns a Session with the given sql.Tx.
  18. // Use it with caution, it's provided for other ORM to interact with.
  19. func NewSessionFromTx(tx *sql.Tx) Session {
  20. return txSession{Tx: tx}
  21. }
  22. func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
  23. return exec(t.Tx, q, args...)
  24. }
  25. func (t txSession) Prepare(q string) (StmtSession, error) {
  26. stmt, err := t.Tx.Prepare(q)
  27. if err != nil {
  28. return nil, err
  29. }
  30. return statement{
  31. query: q,
  32. stmt: stmt,
  33. }, nil
  34. }
  35. func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
  36. return query(t.Tx, func(rows *sql.Rows) error {
  37. return unmarshalRow(v, rows, true)
  38. }, q, args...)
  39. }
  40. func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  41. return query(t.Tx, func(rows *sql.Rows) error {
  42. return unmarshalRow(v, rows, false)
  43. }, q, args...)
  44. }
  45. func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
  46. return query(t.Tx, func(rows *sql.Rows) error {
  47. return unmarshalRows(v, rows, true)
  48. }, q, args...)
  49. }
  50. func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  51. return query(t.Tx, func(rows *sql.Rows) error {
  52. return unmarshalRows(v, rows, false)
  53. }, q, args...)
  54. }
  55. func begin(db *sql.DB) (trans, error) {
  56. tx, err := db.Begin()
  57. if err != nil {
  58. return nil, err
  59. }
  60. return txSession{
  61. Tx: tx,
  62. }, nil
  63. }
  64. func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
  65. conn, err := db.connProv()
  66. if err != nil {
  67. db.onError(err)
  68. return err
  69. }
  70. return transactOnConn(conn, b, fn)
  71. }
  72. func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
  73. var tx trans
  74. tx, err = b(conn)
  75. if err != nil {
  76. return
  77. }
  78. defer func() {
  79. if p := recover(); p != nil {
  80. if e := tx.Rollback(); e != nil {
  81. err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
  82. } else {
  83. err = fmt.Errorf("recoveer from %#v", p)
  84. }
  85. } else if err != nil {
  86. if e := tx.Rollback(); e != nil {
  87. err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
  88. }
  89. } else {
  90. err = tx.Commit()
  91. }
  92. }()
  93. return fn(tx)
  94. }