tx.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package sqlx
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. )
  7. type (
  8. beginnable func(*sql.DB) (trans, error)
  9. trans interface {
  10. Session
  11. Commit() error
  12. Rollback() error
  13. }
  14. txSession struct {
  15. *sql.Tx
  16. }
  17. )
  18. // NewSessionFromTx returns a Session with the given sql.Tx.
  19. // Use it with caution, it's provided for other ORM to interact with.
  20. func NewSessionFromTx(tx *sql.Tx) Session {
  21. return txSession{Tx: tx}
  22. }
  23. func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
  24. return t.ExecCtx(context.Background(), q, args...)
  25. }
  26. func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (result sql.Result, err error) {
  27. ctx, span := startSpan(ctx, "Exec")
  28. defer func() {
  29. endSpan(span, err)
  30. }()
  31. result, err = exec(ctx, t.Tx, q, args...)
  32. return
  33. }
  34. func (t txSession) Prepare(q string) (StmtSession, error) {
  35. return t.PrepareCtx(context.Background(), q)
  36. }
  37. func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSession, err error) {
  38. ctx, span := startSpan(ctx, "Prepare")
  39. defer func() {
  40. endSpan(span, err)
  41. }()
  42. stmt, err := t.Tx.PrepareContext(ctx, q)
  43. if err != nil {
  44. return nil, err
  45. }
  46. return statement{
  47. query: q,
  48. stmt: stmt,
  49. }, nil
  50. }
  51. func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
  52. return t.QueryRowCtx(context.Background(), v, q, args...)
  53. }
  54. func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) (err error) {
  55. ctx, span := startSpan(ctx, "QueryRow")
  56. defer func() {
  57. endSpan(span, err)
  58. }()
  59. return query(ctx, t.Tx, func(rows *sql.Rows) error {
  60. return unmarshalRow(v, rows, true)
  61. }, q, args...)
  62. }
  63. func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  64. return t.QueryRowPartialCtx(context.Background(), v, q, args...)
  65. }
  66. func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
  67. args ...interface{}) (err error) {
  68. ctx, span := startSpan(ctx, "QueryRowPartial")
  69. defer func() {
  70. endSpan(span, err)
  71. }()
  72. return query(ctx, t.Tx, func(rows *sql.Rows) error {
  73. return unmarshalRow(v, rows, false)
  74. }, q, args...)
  75. }
  76. func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
  77. return t.QueryRowsCtx(context.Background(), v, q, args...)
  78. }
  79. func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) (err error) {
  80. ctx, span := startSpan(ctx, "QueryRows")
  81. defer func() {
  82. endSpan(span, err)
  83. }()
  84. return query(ctx, t.Tx, func(rows *sql.Rows) error {
  85. return unmarshalRows(v, rows, true)
  86. }, q, args...)
  87. }
  88. func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  89. return t.QueryRowsPartialCtx(context.Background(), v, q, args...)
  90. }
  91. func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
  92. args ...interface{}) (err error) {
  93. ctx, span := startSpan(ctx, "QueryRowsPartial")
  94. defer func() {
  95. endSpan(span, err)
  96. }()
  97. return query(ctx, t.Tx, func(rows *sql.Rows) error {
  98. return unmarshalRows(v, rows, false)
  99. }, q, args...)
  100. }
  101. func begin(db *sql.DB) (trans, error) {
  102. tx, err := db.Begin()
  103. if err != nil {
  104. return nil, err
  105. }
  106. return txSession{
  107. Tx: tx,
  108. }, nil
  109. }
  110. func transact(ctx context.Context, db *commonSqlConn, b beginnable,
  111. fn func(context.Context, Session) error) (err error) {
  112. conn, err := db.connProv()
  113. if err != nil {
  114. db.onError(err)
  115. return err
  116. }
  117. return transactOnConn(ctx, conn, b, fn)
  118. }
  119. func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
  120. fn func(context.Context, Session) error) (err error) {
  121. var tx trans
  122. tx, err = b(conn)
  123. if err != nil {
  124. return
  125. }
  126. defer func() {
  127. if p := recover(); p != nil {
  128. if e := tx.Rollback(); e != nil {
  129. err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
  130. } else {
  131. err = fmt.Errorf("recoveer from %#v", p)
  132. }
  133. } else if err != nil {
  134. if e := tx.Rollback(); e != nil {
  135. err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
  136. }
  137. } else {
  138. err = tx.Commit()
  139. }
  140. }()
  141. return fn(ctx, tx)
  142. }