tx.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package sqlx
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. )
  7. type (
  8. beginnable func(*sql.DB) (trans, error)
  9. trans interface {
  10. Session
  11. Commit() error
  12. Rollback() error
  13. }
  14. txConn struct {
  15. Session
  16. }
  17. txSession struct {
  18. *sql.Tx
  19. }
  20. )
  21. func (s txConn) RawDB() (*sql.DB, error) {
  22. return nil, errNoRawDBFromTx
  23. }
  24. func (s txConn) Transact(_ func(Session) error) error {
  25. return errCantNestTx
  26. }
  // TransactCtx is the context-aware form of Transact. It likewise rejects
// nesting: neither the context nor the callback is used, and errCantNestTx
// is returned.
func (txConn) TransactCtx(context.Context, func(context.Context, Session) error) error {
	return errCantNestTx
}
  // NewSessionFromTx wraps an existing *sql.Tx in a Session so that queries
// issued through this package run inside that transaction.
// Use it with caution; it exists so that other ORMs can interoperate.
// This function only wraps tx; it does not commit or roll it back.
func NewSessionFromTx(tx *sql.Tx) Session {
	session := txSession{Tx: tx}
	return session
}
  // Exec is ExecCtx with a background context.
func (t txSession) Exec(q string, args ...any) (sql.Result, error) {
	ctx := context.Background()
	return t.ExecCtx(ctx, q, args...)
}
  // ExecCtx executes q with args on the wrapped *sql.Tx via exec.
// The call is bracketed by startSpan(ctx, "Exec") and a deferred endSpan
// that receives the final error value.
func (t txSession) ExecCtx(ctx context.Context, q string, args ...any) (result sql.Result, err error) {
	spanCtx, span := startSpan(ctx, "Exec")
	// A closure (not `defer endSpan(span, err)`) so that endSpan sees the
	// error assigned by the return statement below.
	defer func() { endSpan(span, err) }()

	return exec(spanCtx, t.Tx, q, args...)
}
  // Prepare is PrepareCtx with a background context.
func (t txSession) Prepare(q string) (StmtSession, error) {
	ctx := context.Background()
	return t.PrepareCtx(ctx, q)
}
  49. func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSession, err error) {
  50. ctx, span := startSpan(ctx, "Prepare")
  51. defer func() {
  52. endSpan(span, err)
  53. }()
  54. stmt, err := t.Tx.PrepareContext(ctx, q)
  55. if err != nil {
  56. return nil, err
  57. }
  58. return statement{
  59. query: q,
  60. stmt: stmt,
  61. }, nil
  62. }
  // QueryRow is QueryRowCtx with a background context.
func (t txSession) QueryRow(v any, q string, args ...any) error {
	ctx := context.Background()
	return t.QueryRowCtx(ctx, v, q, args...)
}
  // QueryRowCtx runs q with args on the wrapped *sql.Tx via query and hands
// the resulting rows to unmarshalRow(v, rows, true). The call is bracketed
// by startSpan(ctx, "QueryRow") and a deferred endSpan that receives the
// final error value.
func (t txSession) QueryRowCtx(ctx context.Context, v any, q string, args ...any) (err error) {
	spanCtx, span := startSpan(ctx, "QueryRow")
	defer func() { endSpan(span, err) }()

	scanner := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, true)
	}
	return query(spanCtx, t.Tx, scanner, q, args...)
}
  // QueryRowPartial is QueryRowPartialCtx with a background context.
func (t txSession) QueryRowPartial(v any, q string, args ...any) error {
	ctx := context.Background()
	return t.QueryRowPartialCtx(ctx, v, q, args...)
}
  // QueryRowPartialCtx is QueryRowCtx with unmarshalRow's flag set to false
// (presumably relaxing strict column matching — confirm in unmarshalRow).
// The call is bracketed by startSpan(ctx, "QueryRowPartial") and a deferred
// endSpan that receives the final error value.
func (t txSession) QueryRowPartialCtx(ctx context.Context, v any, q string, args ...any) (err error) {
	spanCtx, span := startSpan(ctx, "QueryRowPartial")
	defer func() { endSpan(span, err) }()

	scanner := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, false)
	}
	return query(spanCtx, t.Tx, scanner, q, args...)
}
  // QueryRows is QueryRowsCtx with a background context.
func (t txSession) QueryRows(v any, q string, args ...any) error {
	ctx := context.Background()
	return t.QueryRowsCtx(ctx, v, q, args...)
}
  // QueryRowsCtx runs q with args on the wrapped *sql.Tx via query and hands
// the resulting rows to unmarshalRows(v, rows, true). The call is bracketed
// by startSpan(ctx, "QueryRows") and a deferred endSpan that receives the
// final error value.
func (t txSession) QueryRowsCtx(ctx context.Context, v any, q string, args ...any) (err error) {
	spanCtx, span := startSpan(ctx, "QueryRows")
	defer func() { endSpan(span, err) }()

	scanner := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, true)
	}
	return query(spanCtx, t.Tx, scanner, q, args...)
}
  // QueryRowsPartial is QueryRowsPartialCtx with a background context.
func (t txSession) QueryRowsPartial(v any, q string, args ...any) error {
	ctx := context.Background()
	return t.QueryRowsPartialCtx(ctx, v, q, args...)
}
  // QueryRowsPartialCtx is QueryRowsCtx with unmarshalRows' flag set to false
// (presumably relaxing strict column matching — confirm in unmarshalRows).
// The call is bracketed by startSpan(ctx, "QueryRowsPartial") and a deferred
// endSpan that receives the final error value.
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v any, q string, args ...any) (err error) {
	spanCtx, span := startSpan(ctx, "QueryRowsPartial")
	defer func() { endSpan(span, err) }()

	scanner := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, false)
	}
	return query(spanCtx, t.Tx, scanner, q, args...)
}
  // begin starts a transaction on db and returns it wrapped in a txSession.
// Its signature matches beginnable.
func begin(db *sql.DB) (trans, error) {
	// DB.Begin is defined as BeginTx with context.Background and default
	// (nil) options; spelling it out makes the missing caller context visible.
	tx, err := db.BeginTx(context.Background(), nil)
	if err != nil {
		return nil, err
	}
	return txSession{Tx: tx}, nil
}
  // transact obtains a connection from db.connProv and runs fn in a
// transaction on it via transactOnConn. If no connection can be obtained,
// the error is reported to db.onError and returned unchanged; fn is not called.
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
	fn func(context.Context, Session) error) (err error) {
	conn, connErr := db.connProv()
	if connErr != nil {
		db.onError(ctx, connErr)
		return connErr
	}

	return transactOnConn(ctx, conn, b, fn)
}
  // transactOnConn starts a transaction on conn using b, hands it to fn as a
// Session, and then finishes it according to how fn exited:
//
//   - fn panicked: the panic is recovered, the transaction is rolled back and
//     an error describing the panic value (and any rollback failure) is
//     returned instead of re-panicking;
//   - fn never returned (for example it called runtime.Goexit, as
//     testing.T.FailNow does): the transaction is rolled back, never committed;
//   - fn returned a non-nil error: the transaction is rolled back and fn's
//     error is returned, annotated only if the rollback itself failed;
//   - fn returned nil: the transaction is committed and Commit's error, if
//     any, is returned.
//
// If b fails, its error is returned and fn is not called. ctx is forwarded to
// fn only; b has no context parameter.
func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
	fn func(context.Context, Session) error) (err error) {
	var tx trans
	tx, err = b(conn)
	if err != nil {
		return err
	}

	// returned is set only after fn comes back normally. During a
	// runtime.Goexit unwind, recover() yields nil and err is still nil, so
	// without this flag the deferred handler would fall through to Commit and
	// persist a transaction whose work was abandoned midway.
	returned := false
	defer func() {
		p := recover()
		switch {
		case p != nil:
			if e := tx.Rollback(); e != nil {
				err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
			} else {
				err = fmt.Errorf("recover from %#v", p)
			}
		case !returned:
			// Goexit, or panic(nil) before Go 1.21 (where recover() reports
			// nil): never commit partially executed work.
			if e := tx.Rollback(); e != nil {
				err = fmt.Errorf("transaction aborted, rollback failed: %w", e)
			} else {
				err = fmt.Errorf("transaction aborted before fn returned")
			}
		case err != nil:
			if e := tx.Rollback(); e != nil {
				err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
			}
		default:
			err = tx.Commit()
		}
	}()

	err = fn(ctx, tx)
	returned = true
	return err
}