sqlconn.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. package sqlx
import (
	"context"
	"database/sql"
	"errors"

	"github.com/zeromicro/go-zero/core/breaker"
	"github.com/zeromicro/go-zero/core/logx"
)
// ErrNotFound is an alias of sql.ErrNoRows, so comparing an error against
// either name checks for the same sentinel value.
var ErrNotFound = sql.ErrNoRows
  10. type (
	// Session stands for raw connections or transaction sessions.
	// Every operation comes in two forms: a plain one, and a Ctx one that
	// additionally takes a context.Context as its first parameter.
	Session interface {
		// Exec and ExecCtx run query with args and return the sql.Result.
		Exec(query string, args ...interface{}) (sql.Result, error)
		ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
		// Prepare and PrepareCtx turn query into a reusable StmtSession.
		Prepare(query string) (StmtSession, error)
		PrepareCtx(ctx context.Context, query string) (StmtSession, error)
		// The QueryRow* and QueryRows* methods run query with args and store
		// the outcome in v.
		// NOTE(review): presumably QueryRow* fill a single value and QueryRows*
		// a collection, with the Partial variants being more lenient about
		// unmatched columns — confirm against unmarshalRow/unmarshalRows.
		QueryRow(v interface{}, query string, args ...interface{}) error
		QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
		QueryRowPartial(v interface{}, query string, args ...interface{}) error
		QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
		QueryRows(v interface{}, query string, args ...interface{}) error
		QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
		QueryRowsPartial(v interface{}, query string, args ...interface{}) error
		QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
	}
	// SqlConn only stands for raw connections, so Transact method can be called.
	// It embeds Session and adds access to the underlying *sql.DB plus the
	// transaction entry points.
	SqlConn interface {
		Session
		// RawDB is for other ORM to operate with, use it with caution.
		// Notice: don't close it.
		RawDB() (*sql.DB, error)
		// Transact hands a Session to fn and returns an error.
		// NOTE(review): commit/rollback semantics live in the transact helper,
		// which is not part of this declaration — confirm there.
		Transact(fn func(Session) error) error
		// TransactCtx is the context-aware form of Transact; fn also receives
		// a context.Context.
		TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
	}
	// SqlOption defines the method to customize a sql connection.
	// Options receive the freshly built *commonSqlConn and may mutate it;
	// the constructors in this file apply them in the order given.
	SqlOption func(*commonSqlConn)
	// StmtSession interface represents a session that can be used to execute statements.
	// Unlike Session, its methods take only arguments: the query text is
	// fixed when the StmtSession is created.
	StmtSession interface {
		// Close releases the statement.
		Close() error
		// Exec and ExecCtx run the statement with args and return the sql.Result.
		Exec(args ...interface{}) (sql.Result, error)
		ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
		// The QueryRow* and QueryRows* methods run the statement with args and
		// store the outcome in v; each has a Ctx form taking a context.Context.
		QueryRow(v interface{}, args ...interface{}) error
		QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
		QueryRowPartial(v interface{}, args ...interface{}) error
		QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
		QueryRows(v interface{}, args ...interface{}) error
		QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
		QueryRowsPartial(v interface{}, args ...interface{}) error
		QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
	}
	// commonSqlConn is the SqlConn implementation built by NewSqlConn and
	// NewSqlConnFromDB.
	//
	// thread-safe
	// Because CORBA doesn't support PREPARE, we need to combine the
	// query arguments into one string and do the underlying query without arguments.
	commonSqlConn struct {
		// connProv is called at the start of every operation to obtain the *sql.DB.
		connProv connProvider
		// onError is invoked with the error whenever connProv fails.
		onError func(error)
		// beginTx is handed to the transact helper by TransactCtx.
		beginTx beginnable
		// brk guards every operation through DoWithAcceptable.
		brk breaker.Breaker
		// accept is an optional extra predicate consulted by acceptable; it may be nil.
		accept func(error) bool
	}
	// connProvider yields the *sql.DB to operate on, or an error when no
	// database handle can be obtained.
	connProvider func() (*sql.DB, error)
	// sessionConn is the minimal exec/query surface needed to run a query
	// given as text plus arguments. The method set mirrors the identically
	// named methods of database/sql's *sql.DB and *sql.Tx.
	sessionConn interface {
		Exec(query string, args ...interface{}) (sql.Result, error)
		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
		Query(query string, args ...interface{}) (*sql.Rows, error)
		QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	}
	// statement is the StmtSession implementation returned by PrepareCtx.
	statement struct {
		// query is the SQL text the statement was prepared from; it is passed
		// along to execStmt/queryStmt together with the arguments.
		query string
		// stmt is the underlying prepared statement.
		stmt *sql.Stmt
	}
	// stmtConn is the minimal exec/query surface of a prepared statement:
	// only arguments are supplied, the query text is already bound. The
	// method set mirrors the identically named methods of *sql.Stmt.
	stmtConn interface {
		Exec(args ...interface{}) (sql.Result, error)
		ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
		Query(args ...interface{}) (*sql.Rows, error)
		QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
	}
  78. )
  79. // NewSqlConn returns a SqlConn with given driver name and datasource.
  80. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  81. conn := &commonSqlConn{
  82. connProv: func() (*sql.DB, error) {
  83. return getSqlConn(driverName, datasource)
  84. },
  85. onError: func(err error) {
  86. logInstanceError(datasource, err)
  87. },
  88. beginTx: begin,
  89. brk: breaker.NewBreaker(),
  90. }
  91. for _, opt := range opts {
  92. opt(conn)
  93. }
  94. return conn
  95. }
  96. // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
  97. // Use it with caution, it's provided for other ORM to interact with.
  98. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
  99. conn := &commonSqlConn{
  100. connProv: func() (*sql.DB, error) {
  101. return db, nil
  102. },
  103. onError: func(err error) {
  104. logx.Errorf("Error on getting sql instance: %v", err)
  105. },
  106. beginTx: begin,
  107. brk: breaker.NewBreaker(),
  108. }
  109. for _, opt := range opts {
  110. opt(conn)
  111. }
  112. return conn
  113. }
  114. func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
  115. return db.ExecCtx(context.Background(), q, args...)
  116. }
  117. func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
  118. result sql.Result, err error) {
  119. err = db.brk.DoWithAcceptable(func() error {
  120. var conn *sql.DB
  121. conn, err = db.connProv()
  122. if err != nil {
  123. db.onError(err)
  124. return err
  125. }
  126. result, err = exec(ctx, conn, q, args...)
  127. return err
  128. }, db.acceptable)
  129. return
  130. }
  131. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  132. return db.PrepareCtx(context.Background(), query)
  133. }
  134. func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
  135. err = db.brk.DoWithAcceptable(func() error {
  136. var conn *sql.DB
  137. conn, err = db.connProv()
  138. if err != nil {
  139. db.onError(err)
  140. return err
  141. }
  142. st, err := conn.PrepareContext(ctx, query)
  143. if err != nil {
  144. return err
  145. }
  146. stmt = statement{
  147. query: query,
  148. stmt: st,
  149. }
  150. return nil
  151. }, db.acceptable)
  152. return
  153. }
  154. func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
  155. return db.QueryRowCtx(context.Background(), v, q, args...)
  156. }
  157. func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
  158. args ...interface{}) error {
  159. return db.queryRows(ctx, func(rows *sql.Rows) error {
  160. return unmarshalRow(v, rows, true)
  161. }, q, args...)
  162. }
  163. func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  164. return db.QueryRowPartialCtx(context.Background(), v, q, args...)
  165. }
  166. func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
  167. q string, args ...interface{}) error {
  168. return db.queryRows(ctx, func(rows *sql.Rows) error {
  169. return unmarshalRow(v, rows, false)
  170. }, q, args...)
  171. }
  172. func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
  173. return db.QueryRowsCtx(context.Background(), v, q, args...)
  174. }
  175. func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
  176. args ...interface{}) error {
  177. return db.queryRows(ctx, func(rows *sql.Rows) error {
  178. return unmarshalRows(v, rows, true)
  179. }, q, args...)
  180. }
  181. func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  182. return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
  183. }
  184. func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
  185. q string, args ...interface{}) error {
  186. return db.queryRows(ctx, func(rows *sql.Rows) error {
  187. return unmarshalRows(v, rows, false)
  188. }, q, args...)
  189. }
  190. func (db *commonSqlConn) RawDB() (*sql.DB, error) {
  191. return db.connProv()
  192. }
  193. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  194. return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
  195. return fn(session)
  196. })
  197. }
  198. func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
  199. return db.brk.DoWithAcceptable(func() error {
  200. return transact(ctx, db, db.beginTx, fn)
  201. }, db.acceptable)
  202. }
  203. func (db *commonSqlConn) acceptable(err error) bool {
  204. ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
  205. if db.accept == nil {
  206. return ok
  207. }
  208. return ok || db.accept(err)
  209. }
  210. func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
  211. q string, args ...interface{}) error {
  212. var qerr error
  213. return db.brk.DoWithAcceptable(func() error {
  214. conn, err := db.connProv()
  215. if err != nil {
  216. db.onError(err)
  217. return err
  218. }
  219. return query(ctx, conn, func(rows *sql.Rows) error {
  220. qerr = scanner(rows)
  221. return qerr
  222. }, q, args...)
  223. }, func(err error) bool {
  224. return qerr == err || db.acceptable(err)
  225. })
  226. }
  227. func (s statement) Close() error {
  228. return s.stmt.Close()
  229. }
  230. func (s statement) Exec(args ...interface{}) (sql.Result, error) {
  231. return s.ExecCtx(context.Background(), args...)
  232. }
  233. func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) {
  234. return execStmt(ctx, s.stmt, s.query, args...)
  235. }
  236. func (s statement) QueryRow(v interface{}, args ...interface{}) error {
  237. return s.QueryRowCtx(context.Background(), v, args...)
  238. }
  239. func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  240. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  241. return unmarshalRow(v, rows, true)
  242. }, s.query, args...)
  243. }
  244. func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
  245. return s.QueryRowPartialCtx(context.Background(), v, args...)
  246. }
  247. func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  248. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  249. return unmarshalRow(v, rows, false)
  250. }, s.query, args...)
  251. }
  252. func (s statement) QueryRows(v interface{}, args ...interface{}) error {
  253. return s.QueryRowsCtx(context.Background(), v, args...)
  254. }
  255. func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  256. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  257. return unmarshalRows(v, rows, true)
  258. }, s.query, args...)
  259. }
  260. func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
  261. return s.QueryRowsPartialCtx(context.Background(), v, args...)
  262. }
  263. func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  264. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  265. return unmarshalRows(v, rows, false)
  266. }, s.query, args...)
  267. }