sqlconn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. package sqlx
  import (
  	"context"
  	"database/sql"
  	"errors"

  	"github.com/zeromicro/go-zero/core/breaker"
  	"github.com/zeromicro/go-zero/core/logx"
  )
  8. // spanName is used to identify the span name for the SQL execution.
  9. const spanName = "sql"
  10. // ErrNotFound is an alias of sql.ErrNoRows
  11. var ErrNotFound = sql.ErrNoRows
  12. type (
  13. // Session stands for raw connections or transaction sessions
  14. Session interface {
  15. Exec(query string, args ...any) (sql.Result, error)
  16. ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
  17. Prepare(query string) (StmtSession, error)
  18. PrepareCtx(ctx context.Context, query string) (StmtSession, error)
  19. QueryRow(v any, query string, args ...any) error
  20. QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
  21. QueryRowPartial(v any, query string, args ...any) error
  22. QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
  23. QueryRows(v any, query string, args ...any) error
  24. QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
  25. QueryRowsPartial(v any, query string, args ...any) error
  26. QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
  27. }
  28. // SqlConn only stands for raw connections, so Transact method can be called.
  29. SqlConn interface {
  30. Session
  31. // RawDB is for other ORM to operate with, use it with caution.
  32. // Notice: don't close it.
  33. RawDB() (*sql.DB, error)
  34. Transact(fn func(Session) error) error
  35. TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
  36. }
  37. // SqlOption defines the method to customize a sql connection.
  38. SqlOption func(*commonSqlConn)
  39. // StmtSession interface represents a session that can be used to execute statements.
  40. StmtSession interface {
  41. Close() error
  42. Exec(args ...any) (sql.Result, error)
  43. ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
  44. QueryRow(v any, args ...any) error
  45. QueryRowCtx(ctx context.Context, v any, args ...any) error
  46. QueryRowPartial(v any, args ...any) error
  47. QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
  48. QueryRows(v any, args ...any) error
  49. QueryRowsCtx(ctx context.Context, v any, args ...any) error
  50. QueryRowsPartial(v any, args ...any) error
  51. QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
  52. }
  53. // thread-safe
  54. // Because CORBA doesn't support PREPARE, so we need to combine the
  55. // query arguments into one string and do underlying query without arguments
  56. commonSqlConn struct {
  57. connProv connProvider
  58. onError func(error)
  59. beginTx beginnable
  60. brk breaker.Breaker
  61. accept func(error) bool
  62. }
  63. connProvider func() (*sql.DB, error)
  64. sessionConn interface {
  65. Exec(query string, args ...any) (sql.Result, error)
  66. ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
  67. Query(query string, args ...any) (*sql.Rows, error)
  68. QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  69. }
  70. statement struct {
  71. query string
  72. stmt *sql.Stmt
  73. }
  74. stmtConn interface {
  75. Exec(args ...any) (sql.Result, error)
  76. ExecContext(ctx context.Context, args ...any) (sql.Result, error)
  77. Query(args ...any) (*sql.Rows, error)
  78. QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
  79. }
  80. )
  // NewSqlConn returns a SqlConn for the given driver name and data source.
  //
  // The *sql.DB is resolved through getSqlConn each time an operation needs a
  // connection. Errors from that lookup are reported via logInstanceError,
  // tagged with datasource. opts are applied in order after the defaults, so
  // they may override any of them.
  func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  	provide := func() (*sql.DB, error) {
  		return getSqlConn(driverName, datasource)
  	}
  	report := func(err error) {
  		logInstanceError(datasource, err)
  	}

  	c := &commonSqlConn{
  		connProv: provide,
  		onError:  report,
  		beginTx:  begin,
  		brk:      breaker.NewBreaker(),
  	}
  	for i := range opts {
  		opts[i](c)
  	}

  	return c
  }
  // NewSqlConnFromDB wraps an existing *sql.DB in a SqlConn.
  //
  // It is provided so that other ORMs can interoperate with this package; use
  // it with caution. The connection provider always yields db. Provider errors
  // (possible only if an option replaces it) are logged through logx. opts run
  // in order after the defaults.
  func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
  	c := &commonSqlConn{
  		beginTx: begin,
  		brk:     breaker.NewBreaker(),
  	}
  	c.connProv = func() (*sql.DB, error) {
  		return db, nil
  	}
  	c.onError = func(err error) {
  		logx.Errorf("Error on getting sql instance: %v", err)
  	}

  	for i := range opts {
  		opts[i](c)
  	}

  	return c
  }
  // Exec runs q with args; it is ExecCtx with a background context.
  func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
  	bg := context.Background()
  	result, err = db.ExecCtx(bg, q, args...)
  	return result, err
  }
  // ExecCtx executes q with args and returns the driver's sql.Result.
  //
  // The call runs inside an "Exec" tracing span that records the final error.
  // It is guarded by db.brk, with db.acceptable classifying errors. A failure
  // to obtain a connection is reported via db.onError. When the breaker
  // returns breaker.ErrServiceUnavailable, the "Exec"/"breaker" error metric
  // is incremented and that error is returned.
  func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
  	result sql.Result, err error) {
  	ctx, span := startSpan(ctx, "Exec")
  	defer func() {
  		endSpan(span, err)
  	}()

  	err = db.brk.DoWithAcceptable(func() error {
  		conn, connErr := db.connProv()
  		if connErr != nil {
  			db.onError(connErr)
  			return connErr
  		}

  		var execErr error
  		result, execErr = exec(ctx, conn, q, args...)
  		return execErr
  	}, db.acceptable)
  	// errors.Is also matches a wrapped rejection, which == would miss.
  	if errors.Is(err, breaker.ErrServiceUnavailable) {
  		metricReqErr.Inc("Exec", "breaker")
  	}

  	return result, err
  }
  // Prepare prepares query; it is PrepareCtx with a background context.
  func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  	bg := context.Background()
  	stmt, err = db.PrepareCtx(bg, query)
  	return stmt, err
  }
  143. func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
  144. ctx, span := startSpan(ctx, "Prepare")
  145. defer func() {
  146. endSpan(span, err)
  147. }()
  148. err = db.brk.DoWithAcceptable(func() error {
  149. var conn *sql.DB
  150. conn, err = db.connProv()
  151. if err != nil {
  152. db.onError(err)
  153. return err
  154. }
  155. st, err := conn.PrepareContext(ctx, query)
  156. if err != nil {
  157. return err
  158. }
  159. stmt = statement{
  160. query: query,
  161. stmt: st,
  162. }
  163. return nil
  164. }, db.acceptable)
  165. if err == breaker.ErrServiceUnavailable {
  166. metricReqErr.Inc("Prepare", "breaker")
  167. }
  168. return
  169. }
  // QueryRow is QueryRowCtx with a background context.
  func (db *commonSqlConn) QueryRow(v any, q string, args ...any) error {
  	bg := context.Background()
  	return db.QueryRowCtx(bg, v, q, args...)
  }
  // QueryRowCtx runs q with args through queryRows inside a "QueryRow" tracing
  // span and unmarshals the row into v with unmarshalRow(v, rows, true).
  // The Partial variant passes false instead.
  func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, q string,
  	args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRow")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRow(v, rows, true)
  	}
  	err = db.queryRows(spanCtx, scan, q, args...)
  	return err
  }
  // QueryRowPartial is QueryRowPartialCtx with a background context.
  func (db *commonSqlConn) QueryRowPartial(v any, q string, args ...any) error {
  	bg := context.Background()
  	return db.QueryRowPartialCtx(bg, v, q, args...)
  }
  // QueryRowPartialCtx runs q with args through queryRows inside a
  // "QueryRowPartial" tracing span and unmarshals the row into v with
  // unmarshalRow(v, rows, false), unlike QueryRowCtx which passes true.
  func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any,
  	q string, args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRowPartial")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRow(v, rows, false)
  	}
  	err = db.queryRows(spanCtx, scan, q, args...)
  	return err
  }
  // QueryRows is QueryRowsCtx with a background context.
  func (db *commonSqlConn) QueryRows(v any, q string, args ...any) error {
  	bg := context.Background()
  	return db.QueryRowsCtx(bg, v, q, args...)
  }
  // QueryRowsCtx runs q with args through queryRows inside a "QueryRows"
  // tracing span and unmarshals the result set into v with
  // unmarshalRows(v, rows, true). The Partial variant passes false instead.
  func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, q string,
  	args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRows")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRows(v, rows, true)
  	}
  	err = db.queryRows(spanCtx, scan, q, args...)
  	return err
  }
  // QueryRowsPartial is QueryRowsPartialCtx with a background context.
  func (db *commonSqlConn) QueryRowsPartial(v any, q string, args ...any) error {
  	bg := context.Background()
  	return db.QueryRowsPartialCtx(bg, v, q, args...)
  }
  // QueryRowsPartialCtx runs q with args through queryRows inside a
  // "QueryRowsPartial" tracing span and unmarshals the result set into v with
  // unmarshalRows(v, rows, false), unlike QueryRowsCtx which passes true.
  func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
  	q string, args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRowsPartial")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRows(v, rows, false)
  	}
  	err = db.queryRows(spanCtx, scan, q, args...)
  	return err
  }
  // RawDB returns the *sql.DB from the connection provider, for use by other
  // ORMs; do not close it. Unlike the other operations, this call bypasses the
  // breaker, and provider errors are returned without calling onError.
  func (db *commonSqlConn) RawDB() (*sql.DB, error) {
  	conn, err := db.connProv()
  	return conn, err
  }
  // Transact is TransactCtx with a background context. fn receives only the
  // Session; the context TransactCtx would pass to it is discarded.
  func (db *commonSqlConn) Transact(fn func(Session) error) error {
  	adapt := func(_ context.Context, s Session) error {
  		return fn(s)
  	}
  	return db.TransactCtx(context.Background(), adapt)
  }
  // TransactCtx runs fn through transact, using db.beginTx to start the
  // transaction.
  //
  // The call runs inside a "Transact" tracing span that records the final
  // error. It is guarded by db.brk, with db.acceptable classifying errors. When
  // the breaker returns breaker.ErrServiceUnavailable, the "Transact"/"breaker"
  // metric is incremented and that error is returned.
  func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
  	ctx, span := startSpan(ctx, "Transact")
  	defer func() {
  		endSpan(span, err)
  	}()

  	err = db.brk.DoWithAcceptable(func() error {
  		return transact(ctx, db, db.beginTx, fn)
  	}, db.acceptable)
  	// errors.Is also matches a wrapped rejection, which == would miss.
  	if errors.Is(err, breaker.ErrServiceUnavailable) {
  		metricReqErr.Inc("Transact", "breaker")
  	}

  	return err
  }
  // acceptable is the predicate handed to the circuit breaker. It reports
  // whether err should be treated as acceptable rather than as a failure.
  //
  // The following are always acceptable: nil, and errors matching
  // sql.ErrNoRows, sql.ErrTxDone or context.Canceled, including wrapped forms.
  // Any other error is acceptable only if the optional db.accept callback says
  // so. db.accept is not consulted for errors already accepted above.
  func (db *commonSqlConn) acceptable(err error) bool {
  	if err == nil {
  		return true
  	}
  	// errors.Is sees through %w wrapping. A TransactCtx callback returning
  	// fmt.Errorf("load user: %w", sql.ErrNoRows), or a driver wrapping
  	// context.Canceled, must not count against the breaker. The previous ==
  	// comparison missed those cases.
  	if errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) ||
  		errors.Is(err, context.Canceled) {
  		return true
  	}
  	if db.accept == nil {
  		return false
  	}

  	return db.accept(err)
  }
  // queryRows runs q with args on the provided *sql.DB via query and hands the
  // resulting rows to scanner. The call is guarded by db.brk.
  //
  // Two kinds of error are treated as acceptable by the breaker:
  //   - an error produced by scanner itself;
  //   - anything db.acceptable accepts.
  // Scanner errors are excluded presumably because they reflect mapping/data
  // problems rather than database health. A failure to obtain a connection is
  // reported via db.onError. When the breaker returns
  // breaker.ErrServiceUnavailable, the "queryRows"/"breaker" metric is
  // incremented.
  func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
  	q string, args ...any) (err error) {
  	var scanErr error
  	err = db.brk.DoWithAcceptable(func() error {
  		conn, connErr := db.connProv()
  		if connErr != nil {
  			db.onError(connErr)
  			return connErr
  		}

  		return query(ctx, conn, func(rows *sql.Rows) error {
  			scanErr = scanner(rows)
  			return scanErr
  		}, q, args...)
  	}, func(e error) bool {
  		// Identity comparison: only the exact error value scanner produced.
  		return e == scanErr || db.acceptable(e)
  	})
  	// errors.Is also matches a wrapped rejection, which == would miss.
  	if errors.Is(err, breaker.ErrServiceUnavailable) {
  		metricReqErr.Inc("queryRows", "breaker")
  	}

  	return err
  }
  271. func (s statement) Close() error {
  272. return s.stmt.Close()
  273. }
  // Exec runs the prepared statement with args; it is ExecCtx with a
  // background context.
  func (s statement) Exec(args ...any) (sql.Result, error) {
  	bg := context.Background()
  	return s.ExecCtx(bg, args...)
  }
  // ExecCtx runs the prepared statement with args via execStmt, inside an
  // "Exec" tracing span that records the final error.
  func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
  	spanCtx, span := startSpan(ctx, "Exec")
  	defer func() {
  		endSpan(span, err)
  	}()

  	result, err = execStmt(spanCtx, s.stmt, s.query, args...)
  	return result, err
  }
  // QueryRow is QueryRowCtx with a background context.
  func (s statement) QueryRow(v any, args ...any) error {
  	bg := context.Background()
  	return s.QueryRowCtx(bg, v, args...)
  }
  // QueryRowCtx runs the prepared statement via queryStmt inside a "QueryRow"
  // tracing span and unmarshals the row into v with
  // unmarshalRow(v, rows, true). The Partial variant passes false instead.
  func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRow")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRow(v, rows, true)
  	}
  	err = queryStmt(spanCtx, s.stmt, scan, s.query, args...)
  	return err
  }
  // QueryRowPartial is QueryRowPartialCtx with a background context.
  func (s statement) QueryRowPartial(v any, args ...any) error {
  	bg := context.Background()
  	return s.QueryRowPartialCtx(bg, v, args...)
  }
  // QueryRowPartialCtx runs the prepared statement via queryStmt inside a
  // "QueryRowPartial" tracing span and unmarshals the row into v with
  // unmarshalRow(v, rows, false), unlike QueryRowCtx which passes true.
  func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRowPartial")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRow(v, rows, false)
  	}
  	err = queryStmt(spanCtx, s.stmt, scan, s.query, args...)
  	return err
  }
  // QueryRows is QueryRowsCtx with a background context.
  func (s statement) QueryRows(v any, args ...any) error {
  	bg := context.Background()
  	return s.QueryRowsCtx(bg, v, args...)
  }
  // QueryRowsCtx runs the prepared statement via queryStmt inside a
  // "QueryRows" tracing span and unmarshals the result set into v with
  // unmarshalRows(v, rows, true). The Partial variant passes false instead.
  func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRows")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRows(v, rows, true)
  	}
  	err = queryStmt(spanCtx, s.stmt, scan, s.query, args...)
  	return err
  }
  // QueryRowsPartial is QueryRowsPartialCtx with a background context.
  func (s statement) QueryRowsPartial(v any, args ...any) error {
  	bg := context.Background()
  	return s.QueryRowsPartialCtx(bg, v, args...)
  }
  // QueryRowsPartialCtx runs the prepared statement via queryStmt inside a
  // "QueryRowsPartial" tracing span and unmarshals the result set into v with
  // unmarshalRows(v, rows, false), unlike QueryRowsCtx which passes true.
  func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
  	spanCtx, span := startSpan(ctx, "QueryRowsPartial")
  	defer func() {
  		endSpan(span, err)
  	}()

  	scan := func(rows *sql.Rows) error {
  		return unmarshalRows(v, rows, false)
  	}
  	err = queryStmt(spanCtx, s.stmt, scan, s.query, args...)
  	return err
  }