sqlconn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. package sqlx
import (
	"context"
	"database/sql"
	"errors"

	"github.com/zeromicro/go-zero/core/breaker"
	"github.com/zeromicro/go-zero/core/logx"
)
  8. // spanName is used to identify the span name for the SQL execution.
  9. const spanName = "sql"
  10. // ErrNotFound is an alias of sql.ErrNoRows
  11. var ErrNotFound = sql.ErrNoRows
  12. type (
  13. // Session stands for raw connections or transaction sessions
  14. Session interface {
  15. Exec(query string, args ...interface{}) (sql.Result, error)
  16. ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
  17. Prepare(query string) (StmtSession, error)
  18. PrepareCtx(ctx context.Context, query string) (StmtSession, error)
  19. QueryRow(v interface{}, query string, args ...interface{}) error
  20. QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  21. QueryRowPartial(v interface{}, query string, args ...interface{}) error
  22. QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  23. QueryRows(v interface{}, query string, args ...interface{}) error
  24. QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  25. QueryRowsPartial(v interface{}, query string, args ...interface{}) error
  26. QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  27. }
  28. // SqlConn only stands for raw connections, so Transact method can be called.
  29. SqlConn interface {
  30. Session
  31. // RawDB is for other ORM to operate with, use it with caution.
  32. // Notice: don't close it.
  33. RawDB() (*sql.DB, error)
  34. Transact(fn func(Session) error) error
  35. TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
  36. }
  37. // SqlOption defines the method to customize a sql connection.
  38. SqlOption func(*commonSqlConn)
  39. // StmtSession interface represents a session that can be used to execute statements.
  40. StmtSession interface {
  41. Close() error
  42. Exec(args ...interface{}) (sql.Result, error)
  43. ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
  44. QueryRow(v interface{}, args ...interface{}) error
  45. QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
  46. QueryRowPartial(v interface{}, args ...interface{}) error
  47. QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
  48. QueryRows(v interface{}, args ...interface{}) error
  49. QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
  50. QueryRowsPartial(v interface{}, args ...interface{}) error
  51. QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
  52. }
  53. // thread-safe
  54. // Because CORBA doesn't support PREPARE, so we need to combine the
  55. // query arguments into one string and do underlying query without arguments
  56. commonSqlConn struct {
  57. connProv connProvider
  58. onError func(error)
  59. beginTx beginnable
  60. brk breaker.Breaker
  61. accept func(error) bool
  62. }
  63. connProvider func() (*sql.DB, error)
  64. sessionConn interface {
  65. Exec(query string, args ...interface{}) (sql.Result, error)
  66. ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
  67. Query(query string, args ...interface{}) (*sql.Rows, error)
  68. QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
  69. }
  70. statement struct {
  71. query string
  72. stmt *sql.Stmt
  73. }
  74. stmtConn interface {
  75. Exec(args ...interface{}) (sql.Result, error)
  76. ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
  77. Query(args ...interface{}) (*sql.Rows, error)
  78. QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
  79. }
  80. )
  81. // NewSqlConn returns a SqlConn with given driver name and datasource.
  82. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  83. conn := &commonSqlConn{
  84. connProv: func() (*sql.DB, error) {
  85. return getSqlConn(driverName, datasource)
  86. },
  87. onError: func(err error) {
  88. logInstanceError(datasource, err)
  89. },
  90. beginTx: begin,
  91. brk: breaker.NewBreaker(),
  92. }
  93. for _, opt := range opts {
  94. opt(conn)
  95. }
  96. return conn
  97. }
  98. // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
  99. // Use it with caution, it's provided for other ORM to interact with.
  100. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
  101. conn := &commonSqlConn{
  102. connProv: func() (*sql.DB, error) {
  103. return db, nil
  104. },
  105. onError: func(err error) {
  106. logx.Errorf("Error on getting sql instance: %v", err)
  107. },
  108. beginTx: begin,
  109. brk: breaker.NewBreaker(),
  110. }
  111. for _, opt := range opts {
  112. opt(conn)
  113. }
  114. return conn
  115. }
  116. func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
  117. return db.ExecCtx(context.Background(), q, args...)
  118. }
  119. func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
  120. result sql.Result, err error) {
  121. ctx, span := startSpan(ctx, "Exec")
  122. defer func() {
  123. endSpan(span, err)
  124. }()
  125. err = db.brk.DoWithAcceptable(func() error {
  126. var conn *sql.DB
  127. conn, err = db.connProv()
  128. if err != nil {
  129. db.onError(err)
  130. return err
  131. }
  132. result, err = exec(ctx, conn, q, args...)
  133. return err
  134. }, db.acceptable)
  135. if err == breaker.ErrServiceUnavailable {
  136. metricReqErr.Inc("Exec", "breaker")
  137. }
  138. return
  139. }
  140. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  141. return db.PrepareCtx(context.Background(), query)
  142. }
  143. func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
  144. ctx, span := startSpan(ctx, "Prepare")
  145. defer func() {
  146. endSpan(span, err)
  147. }()
  148. err = db.brk.DoWithAcceptable(func() error {
  149. var conn *sql.DB
  150. conn, err = db.connProv()
  151. if err != nil {
  152. db.onError(err)
  153. return err
  154. }
  155. st, err := conn.PrepareContext(ctx, query)
  156. if err != nil {
  157. return err
  158. }
  159. stmt = statement{
  160. query: query,
  161. stmt: st,
  162. }
  163. return nil
  164. }, db.acceptable)
  165. if err == breaker.ErrServiceUnavailable {
  166. metricReqErr.Inc("Prepare", "breaker")
  167. }
  168. return
  169. }
  170. func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
  171. return db.QueryRowCtx(context.Background(), v, q, args...)
  172. }
  173. func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
  174. args ...interface{}) (err error) {
  175. ctx, span := startSpan(ctx, "QueryRow")
  176. defer func() {
  177. endSpan(span, err)
  178. }()
  179. return db.queryRows(ctx, func(rows *sql.Rows) error {
  180. return unmarshalRow(v, rows, true)
  181. }, q, args...)
  182. }
  183. func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  184. return db.QueryRowPartialCtx(context.Background(), v, q, args...)
  185. }
  186. func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
  187. q string, args ...interface{}) (err error) {
  188. ctx, span := startSpan(ctx, "QueryRowPartial")
  189. defer func() {
  190. endSpan(span, err)
  191. }()
  192. return db.queryRows(ctx, func(rows *sql.Rows) error {
  193. return unmarshalRow(v, rows, false)
  194. }, q, args...)
  195. }
  196. func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
  197. return db.QueryRowsCtx(context.Background(), v, q, args...)
  198. }
  199. func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
  200. args ...interface{}) (err error) {
  201. ctx, span := startSpan(ctx, "QueryRows")
  202. defer func() {
  203. endSpan(span, err)
  204. }()
  205. return db.queryRows(ctx, func(rows *sql.Rows) error {
  206. return unmarshalRows(v, rows, true)
  207. }, q, args...)
  208. }
  209. func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  210. return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
  211. }
  212. func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
  213. q string, args ...interface{}) (err error) {
  214. ctx, span := startSpan(ctx, "QueryRowsPartial")
  215. defer func() {
  216. endSpan(span, err)
  217. }()
  218. return db.queryRows(ctx, func(rows *sql.Rows) error {
  219. return unmarshalRows(v, rows, false)
  220. }, q, args...)
  221. }
  222. func (db *commonSqlConn) RawDB() (*sql.DB, error) {
  223. return db.connProv()
  224. }
  225. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  226. return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
  227. return fn(session)
  228. })
  229. }
  230. func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
  231. ctx, span := startSpan(ctx, "Transact")
  232. defer func() {
  233. endSpan(span, err)
  234. }()
  235. err = db.brk.DoWithAcceptable(func() error {
  236. return transact(ctx, db, db.beginTx, fn)
  237. }, db.acceptable)
  238. if err == breaker.ErrServiceUnavailable {
  239. metricReqErr.Inc("Transact", "breaker")
  240. }
  241. return
  242. }
  243. func (db *commonSqlConn) acceptable(err error) bool {
  244. ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
  245. if db.accept == nil {
  246. return ok
  247. }
  248. return ok || db.accept(err)
  249. }
  250. func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
  251. q string, args ...interface{}) (err error) {
  252. var qerr error
  253. err = db.brk.DoWithAcceptable(func() error {
  254. conn, err := db.connProv()
  255. if err != nil {
  256. db.onError(err)
  257. return err
  258. }
  259. return query(ctx, conn, func(rows *sql.Rows) error {
  260. qerr = scanner(rows)
  261. return qerr
  262. }, q, args...)
  263. }, func(err error) bool {
  264. return qerr == err || db.acceptable(err)
  265. })
  266. if err == breaker.ErrServiceUnavailable {
  267. metricReqErr.Inc("queryRows", "breaker")
  268. }
  269. return
  270. }
  271. func (s statement) Close() error {
  272. return s.stmt.Close()
  273. }
  274. func (s statement) Exec(args ...interface{}) (sql.Result, error) {
  275. return s.ExecCtx(context.Background(), args...)
  276. }
  277. func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (result sql.Result, err error) {
  278. ctx, span := startSpan(ctx, "Exec")
  279. defer func() {
  280. endSpan(span, err)
  281. }()
  282. return execStmt(ctx, s.stmt, s.query, args...)
  283. }
  284. func (s statement) QueryRow(v interface{}, args ...interface{}) error {
  285. return s.QueryRowCtx(context.Background(), v, args...)
  286. }
  287. func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
  288. ctx, span := startSpan(ctx, "QueryRow")
  289. defer func() {
  290. endSpan(span, err)
  291. }()
  292. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  293. return unmarshalRow(v, rows, true)
  294. }, s.query, args...)
  295. }
  296. func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
  297. return s.QueryRowPartialCtx(context.Background(), v, args...)
  298. }
  299. func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
  300. ctx, span := startSpan(ctx, "QueryRowPartial")
  301. defer func() {
  302. endSpan(span, err)
  303. }()
  304. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  305. return unmarshalRow(v, rows, false)
  306. }, s.query, args...)
  307. }
  308. func (s statement) QueryRows(v interface{}, args ...interface{}) error {
  309. return s.QueryRowsCtx(context.Background(), v, args...)
  310. }
  311. func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
  312. ctx, span := startSpan(ctx, "QueryRows")
  313. defer func() {
  314. endSpan(span, err)
  315. }()
  316. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  317. return unmarshalRows(v, rows, true)
  318. }, s.query, args...)
  319. }
  320. func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
  321. return s.QueryRowsPartialCtx(context.Background(), v, args...)
  322. }
  323. func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
  324. ctx, span := startSpan(ctx, "QueryRowsPartial")
  325. defer func() {
  326. endSpan(span, err)
  327. }()
  328. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  329. return unmarshalRows(v, rows, false)
  330. }, s.query, args...)
  331. }