sqlconn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. package sqlx
  import (
  	"context"
  	"database/sql"
  	"errors"

  	"github.com/zeromicro/go-zero/core/breaker"
  	"github.com/zeromicro/go-zero/core/logx"
  )
  // spanName is the name given to the tracing span recorded for SQL
  // execution.
  const (
  	spanName = "sql"
  )
  10. type (
  11. // Session stands for raw connections or transaction sessions
  12. Session interface {
  13. Exec(query string, args ...any) (sql.Result, error)
  14. ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
  15. Prepare(query string) (StmtSession, error)
  16. PrepareCtx(ctx context.Context, query string) (StmtSession, error)
  17. QueryRow(v any, query string, args ...any) error
  18. QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
  19. QueryRowPartial(v any, query string, args ...any) error
  20. QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
  21. QueryRows(v any, query string, args ...any) error
  22. QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
  23. QueryRowsPartial(v any, query string, args ...any) error
  24. QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
  25. }
  26. // SqlConn only stands for raw connections, so Transact method can be called.
  27. SqlConn interface {
  28. Session
  29. // RawDB is for other ORM to operate with, use it with caution.
  30. // Notice: don't close it.
  31. RawDB() (*sql.DB, error)
  32. Transact(fn func(Session) error) error
  33. TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
  34. }
  35. // SqlOption defines the method to customize a sql connection.
  36. SqlOption func(*commonSqlConn)
  37. // StmtSession interface represents a session that can be used to execute statements.
  38. StmtSession interface {
  39. Close() error
  40. Exec(args ...any) (sql.Result, error)
  41. ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
  42. QueryRow(v any, args ...any) error
  43. QueryRowCtx(ctx context.Context, v any, args ...any) error
  44. QueryRowPartial(v any, args ...any) error
  45. QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
  46. QueryRows(v any, args ...any) error
  47. QueryRowsCtx(ctx context.Context, v any, args ...any) error
  48. QueryRowsPartial(v any, args ...any) error
  49. QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
  50. }
  51. // thread-safe
  52. // Because CORBA doesn't support PREPARE, so we need to combine the
  53. // query arguments into one string and do underlying query without arguments
  54. commonSqlConn struct {
  55. connProv connProvider
  56. onError func(context.Context, error)
  57. beginTx beginnable
  58. brk breaker.Breaker
  59. accept func(error) bool
  60. }
  61. connProvider func() (*sql.DB, error)
  62. sessionConn interface {
  63. Exec(query string, args ...any) (sql.Result, error)
  64. ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
  65. Query(query string, args ...any) (*sql.Rows, error)
  66. QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  67. }
  68. statement struct {
  69. query string
  70. stmt *sql.Stmt
  71. }
  72. stmtConn interface {
  73. Exec(args ...any) (sql.Result, error)
  74. ExecContext(ctx context.Context, args ...any) (sql.Result, error)
  75. Query(args ...any) (*sql.Rows, error)
  76. QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
  77. }
  78. )
  79. // NewSqlConn returns a SqlConn with given driver name and datasource.
  80. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  81. conn := &commonSqlConn{
  82. connProv: func() (*sql.DB, error) {
  83. return getSqlConn(driverName, datasource)
  84. },
  85. onError: func(ctx context.Context, err error) {
  86. logInstanceError(ctx, datasource, err)
  87. },
  88. beginTx: begin,
  89. brk: breaker.NewBreaker(),
  90. }
  91. for _, opt := range opts {
  92. opt(conn)
  93. }
  94. return conn
  95. }
  96. // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
  97. // Use it with caution, it's provided for other ORM to interact with.
  98. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
  99. conn := &commonSqlConn{
  100. connProv: func() (*sql.DB, error) {
  101. return db, nil
  102. },
  103. onError: func(ctx context.Context, err error) {
  104. logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err)
  105. },
  106. beginTx: begin,
  107. brk: breaker.NewBreaker(),
  108. }
  109. for _, opt := range opts {
  110. opt(conn)
  111. }
  112. return conn
  113. }
  114. // NewSqlConnFromSession returns a SqlConn with the given session.
  115. func NewSqlConnFromSession(session Session) SqlConn {
  116. return txConn{
  117. Session: session,
  118. }
  119. }
  120. func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
  121. return db.ExecCtx(context.Background(), q, args...)
  122. }
  123. func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
  124. result sql.Result, err error) {
  125. ctx, span := startSpan(ctx, "Exec")
  126. defer func() {
  127. endSpan(span, err)
  128. }()
  129. err = db.brk.DoWithAcceptable(func() error {
  130. var conn *sql.DB
  131. conn, err = db.connProv()
  132. if err != nil {
  133. db.onError(ctx, err)
  134. return err
  135. }
  136. result, err = exec(ctx, conn, q, args...)
  137. return err
  138. }, db.acceptable)
  139. if err == breaker.ErrServiceUnavailable {
  140. metricReqErr.Inc("Exec", "breaker")
  141. }
  142. return
  143. }
  144. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  145. return db.PrepareCtx(context.Background(), query)
  146. }
  147. func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
  148. ctx, span := startSpan(ctx, "Prepare")
  149. defer func() {
  150. endSpan(span, err)
  151. }()
  152. err = db.brk.DoWithAcceptable(func() error {
  153. var conn *sql.DB
  154. conn, err = db.connProv()
  155. if err != nil {
  156. db.onError(ctx, err)
  157. return err
  158. }
  159. st, err := conn.PrepareContext(ctx, query)
  160. if err != nil {
  161. return err
  162. }
  163. stmt = statement{
  164. query: query,
  165. stmt: st,
  166. }
  167. return nil
  168. }, db.acceptable)
  169. if err == breaker.ErrServiceUnavailable {
  170. metricReqErr.Inc("Prepare", "breaker")
  171. }
  172. return
  173. }
  174. func (db *commonSqlConn) QueryRow(v any, q string, args ...any) error {
  175. return db.QueryRowCtx(context.Background(), v, q, args...)
  176. }
  177. func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, q string,
  178. args ...any) (err error) {
  179. ctx, span := startSpan(ctx, "QueryRow")
  180. defer func() {
  181. endSpan(span, err)
  182. }()
  183. return db.queryRows(ctx, func(rows *sql.Rows) error {
  184. return unmarshalRow(v, rows, true)
  185. }, q, args...)
  186. }
  187. func (db *commonSqlConn) QueryRowPartial(v any, q string, args ...any) error {
  188. return db.QueryRowPartialCtx(context.Background(), v, q, args...)
  189. }
  190. func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any,
  191. q string, args ...any) (err error) {
  192. ctx, span := startSpan(ctx, "QueryRowPartial")
  193. defer func() {
  194. endSpan(span, err)
  195. }()
  196. return db.queryRows(ctx, func(rows *sql.Rows) error {
  197. return unmarshalRow(v, rows, false)
  198. }, q, args...)
  199. }
  200. func (db *commonSqlConn) QueryRows(v any, q string, args ...any) error {
  201. return db.QueryRowsCtx(context.Background(), v, q, args...)
  202. }
  203. func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, q string,
  204. args ...any) (err error) {
  205. ctx, span := startSpan(ctx, "QueryRows")
  206. defer func() {
  207. endSpan(span, err)
  208. }()
  209. return db.queryRows(ctx, func(rows *sql.Rows) error {
  210. return unmarshalRows(v, rows, true)
  211. }, q, args...)
  212. }
  213. func (db *commonSqlConn) QueryRowsPartial(v any, q string, args ...any) error {
  214. return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
  215. }
  216. func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
  217. q string, args ...any) (err error) {
  218. ctx, span := startSpan(ctx, "QueryRowsPartial")
  219. defer func() {
  220. endSpan(span, err)
  221. }()
  222. return db.queryRows(ctx, func(rows *sql.Rows) error {
  223. return unmarshalRows(v, rows, false)
  224. }, q, args...)
  225. }
  226. func (db *commonSqlConn) RawDB() (*sql.DB, error) {
  227. return db.connProv()
  228. }
  229. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  230. return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
  231. return fn(session)
  232. })
  233. }
  234. func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
  235. ctx, span := startSpan(ctx, "Transact")
  236. defer func() {
  237. endSpan(span, err)
  238. }()
  239. err = db.brk.DoWithAcceptable(func() error {
  240. return transact(ctx, db, db.beginTx, fn)
  241. }, db.acceptable)
  242. if err == breaker.ErrServiceUnavailable {
  243. metricReqErr.Inc("Transact", "breaker")
  244. }
  245. return
  246. }
  247. func (db *commonSqlConn) acceptable(err error) bool {
  248. ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
  249. if db.accept == nil {
  250. return ok
  251. }
  252. return ok || db.accept(err)
  253. }
  254. func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
  255. q string, args ...any) (err error) {
  256. var qerr error
  257. err = db.brk.DoWithAcceptable(func() error {
  258. conn, err := db.connProv()
  259. if err != nil {
  260. db.onError(ctx, err)
  261. return err
  262. }
  263. return query(ctx, conn, func(rows *sql.Rows) error {
  264. qerr = scanner(rows)
  265. return qerr
  266. }, q, args...)
  267. }, func(err error) bool {
  268. return qerr == err || db.acceptable(err)
  269. })
  270. if err == breaker.ErrServiceUnavailable {
  271. metricReqErr.Inc("queryRows", "breaker")
  272. }
  273. return
  274. }
  275. func (s statement) Close() error {
  276. return s.stmt.Close()
  277. }
  278. func (s statement) Exec(args ...any) (sql.Result, error) {
  279. return s.ExecCtx(context.Background(), args...)
  280. }
  281. func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
  282. ctx, span := startSpan(ctx, "Exec")
  283. defer func() {
  284. endSpan(span, err)
  285. }()
  286. return execStmt(ctx, s.stmt, s.query, args...)
  287. }
  288. func (s statement) QueryRow(v any, args ...any) error {
  289. return s.QueryRowCtx(context.Background(), v, args...)
  290. }
  291. func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
  292. ctx, span := startSpan(ctx, "QueryRow")
  293. defer func() {
  294. endSpan(span, err)
  295. }()
  296. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  297. return unmarshalRow(v, rows, true)
  298. }, s.query, args...)
  299. }
  300. func (s statement) QueryRowPartial(v any, args ...any) error {
  301. return s.QueryRowPartialCtx(context.Background(), v, args...)
  302. }
  303. func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
  304. ctx, span := startSpan(ctx, "QueryRowPartial")
  305. defer func() {
  306. endSpan(span, err)
  307. }()
  308. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  309. return unmarshalRow(v, rows, false)
  310. }, s.query, args...)
  311. }
  312. func (s statement) QueryRows(v any, args ...any) error {
  313. return s.QueryRowsCtx(context.Background(), v, args...)
  314. }
  315. func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
  316. ctx, span := startSpan(ctx, "QueryRows")
  317. defer func() {
  318. endSpan(span, err)
  319. }()
  320. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  321. return unmarshalRows(v, rows, true)
  322. }, s.query, args...)
  323. }
  324. func (s statement) QueryRowsPartial(v any, args ...any) error {
  325. return s.QueryRowsPartialCtx(context.Background(), v, args...)
  326. }
  327. func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
  328. ctx, span := startSpan(ctx, "QueryRowsPartial")
  329. defer func() {
  330. endSpan(span, err)
  331. }()
  332. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  333. return unmarshalRows(v, rows, false)
  334. }, s.query, args...)
  335. }