sqlconn.go 6.3 KB


  1. package sqlx
import (
	"database/sql"
	"errors"

	"github.com/tal-tech/go-zero/core/breaker"
	"github.com/tal-tech/go-zero/core/logx"
)
  7. // ErrNotFound is an alias of sql.ErrNoRows
  8. var ErrNotFound = sql.ErrNoRows
  9. type (
  10. // Session stands for raw connections or transaction sessions
  11. Session interface {
  12. Exec(query string, args ...interface{}) (sql.Result, error)
  13. Prepare(query string) (StmtSession, error)
  14. QueryRow(v interface{}, query string, args ...interface{}) error
  15. QueryRowPartial(v interface{}, query string, args ...interface{}) error
  16. QueryRows(v interface{}, query string, args ...interface{}) error
  17. QueryRowsPartial(v interface{}, query string, args ...interface{}) error
  18. }
  19. // SqlConn only stands for raw connections, so Transact method can be called.
  20. SqlConn interface {
  21. Session
  22. // RawDB is for other ORM to operate with, use it with caution.
  23. // Notice: don't close it.
  24. RawDB() (*sql.DB, error)
  25. Transact(func(session Session) error) error
  26. }
  27. // SqlOption defines the method to customize a sql connection.
  28. SqlOption func(*commonSqlConn)
  29. // StmtSession interface represents a session that can be used to execute statements.
  30. StmtSession interface {
  31. Close() error
  32. Exec(args ...interface{}) (sql.Result, error)
  33. QueryRow(v interface{}, args ...interface{}) error
  34. QueryRowPartial(v interface{}, args ...interface{}) error
  35. QueryRows(v interface{}, args ...interface{}) error
  36. QueryRowsPartial(v interface{}, args ...interface{}) error
  37. }
  38. // thread-safe
  39. // Because CORBA doesn't support PREPARE, so we need to combine the
  40. // query arguments into one string and do underlying query without arguments
  41. commonSqlConn struct {
  42. connProv connProvider
  43. onError func(error)
  44. beginTx beginnable
  45. brk breaker.Breaker
  46. accept func(error) bool
  47. }
  48. connProvider func() (*sql.DB, error)
  49. sessionConn interface {
  50. Exec(query string, args ...interface{}) (sql.Result, error)
  51. Query(query string, args ...interface{}) (*sql.Rows, error)
  52. }
  53. statement struct {
  54. query string
  55. stmt *sql.Stmt
  56. }
  57. stmtConn interface {
  58. Exec(args ...interface{}) (sql.Result, error)
  59. Query(args ...interface{}) (*sql.Rows, error)
  60. }
  61. )
  62. // NewSqlConn returns a SqlConn with given driver name and datasource.
  63. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  64. conn := &commonSqlConn{
  65. connProv: func() (*sql.DB, error) {
  66. return getSqlConn(driverName, datasource)
  67. },
  68. onError: func(err error) {
  69. logInstanceError(datasource, err)
  70. },
  71. beginTx: begin,
  72. brk: breaker.NewBreaker(),
  73. }
  74. for _, opt := range opts {
  75. opt(conn)
  76. }
  77. return conn
  78. }
  79. // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
  80. // Use it with caution, it's provided for other ORM to interact with.
  81. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
  82. conn := &commonSqlConn{
  83. connProv: func() (*sql.DB, error) {
  84. return db, nil
  85. },
  86. onError: func(err error) {
  87. logx.Errorf("Error on getting sql instance: %v", err)
  88. },
  89. beginTx: begin,
  90. brk: breaker.NewBreaker(),
  91. }
  92. for _, opt := range opts {
  93. opt(conn)
  94. }
  95. return conn
  96. }
  97. func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
  98. err = db.brk.DoWithAcceptable(func() error {
  99. var conn *sql.DB
  100. conn, err = db.connProv()
  101. if err != nil {
  102. db.onError(err)
  103. return err
  104. }
  105. result, err = exec(conn, q, args...)
  106. return err
  107. }, db.acceptable)
  108. return
  109. }
  110. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  111. err = db.brk.DoWithAcceptable(func() error {
  112. var conn *sql.DB
  113. conn, err = db.connProv()
  114. if err != nil {
  115. db.onError(err)
  116. return err
  117. }
  118. st, err := conn.Prepare(query)
  119. if err != nil {
  120. return err
  121. }
  122. stmt = statement{
  123. query: query,
  124. stmt: st,
  125. }
  126. return nil
  127. }, db.acceptable)
  128. return
  129. }
  130. func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
  131. return db.queryRows(func(rows *sql.Rows) error {
  132. return unmarshalRow(v, rows, true)
  133. }, q, args...)
  134. }
  135. func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  136. return db.queryRows(func(rows *sql.Rows) error {
  137. return unmarshalRow(v, rows, false)
  138. }, q, args...)
  139. }
  140. func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
  141. return db.queryRows(func(rows *sql.Rows) error {
  142. return unmarshalRows(v, rows, true)
  143. }, q, args...)
  144. }
  145. func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  146. return db.queryRows(func(rows *sql.Rows) error {
  147. return unmarshalRows(v, rows, false)
  148. }, q, args...)
  149. }
  150. func (db *commonSqlConn) RawDB() (*sql.DB, error) {
  151. return db.connProv()
  152. }
  153. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  154. return db.brk.DoWithAcceptable(func() error {
  155. return transact(db, db.beginTx, fn)
  156. }, db.acceptable)
  157. }
  158. func (db *commonSqlConn) acceptable(err error) bool {
  159. ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
  160. if db.accept == nil {
  161. return ok
  162. }
  163. return ok || db.accept(err)
  164. }
  165. func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
  166. var qerr error
  167. return db.brk.DoWithAcceptable(func() error {
  168. conn, err := db.connProv()
  169. if err != nil {
  170. db.onError(err)
  171. return err
  172. }
  173. return query(conn, func(rows *sql.Rows) error {
  174. qerr = scanner(rows)
  175. return qerr
  176. }, q, args...)
  177. }, func(err error) bool {
  178. return qerr == err || db.acceptable(err)
  179. })
  180. }
  181. func (s statement) Close() error {
  182. return s.stmt.Close()
  183. }
  184. func (s statement) Exec(args ...interface{}) (sql.Result, error) {
  185. return execStmt(s.stmt, s.query, args...)
  186. }
  187. func (s statement) QueryRow(v interface{}, args ...interface{}) error {
  188. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  189. return unmarshalRow(v, rows, true)
  190. }, s.query, args...)
  191. }
  192. func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
  193. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  194. return unmarshalRow(v, rows, false)
  195. }, s.query, args...)
  196. }
  197. func (s statement) QueryRows(v interface{}, args ...interface{}) error {
  198. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  199. return unmarshalRows(v, rows, true)
  200. }, s.query, args...)
  201. }
  202. func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
  203. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  204. return unmarshalRows(v, rows, false)
  205. }, s.query, args...)
  206. }