sqlconn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. package sqlx
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "github.com/wuntsong-org/go-zero-plus/core/breaker"
  7. "github.com/wuntsong-org/go-zero-plus/core/logx"
  8. )
  // spanName is the span name under which SQL executions are identified.
  const (
  	spanName = "sql"
  )
  11. type (
  12. // Session stands for raw connections or transaction sessions
  13. Session interface {
  14. Exec(query string, args ...any) (sql.Result, error)
  15. ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
  16. Prepare(query string) (StmtSession, error)
  17. PrepareCtx(ctx context.Context, query string) (StmtSession, error)
  18. QueryRow(v any, query string, args ...any) error
  19. QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
  20. QueryRowPartial(v any, query string, args ...any) error
  21. QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
  22. QueryRows(v any, query string, args ...any) error
  23. QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
  24. QueryRowsPartial(v any, query string, args ...any) error
  25. QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
  26. }
  27. // SqlConn only stands for raw connections, so Transact method can be called.
  28. SqlConn interface {
  29. Session
  30. // RawDB is for other ORM to operate with, use it with caution.
  31. // Notice: don't close it.
  32. RawDB() (*sql.DB, error)
  33. Transact(fn func(Session) error) error
  34. TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
  35. }
  36. // SqlOption defines the method to customize a sql connection.
  37. SqlOption func(*commonSqlConn)
  38. // StmtSession interface represents a session that can be used to execute statements.
  39. StmtSession interface {
  40. Close() error
  41. Exec(args ...any) (sql.Result, error)
  42. ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
  43. QueryRow(v any, args ...any) error
  44. QueryRowCtx(ctx context.Context, v any, args ...any) error
  45. QueryRowPartial(v any, args ...any) error
  46. QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
  47. QueryRows(v any, args ...any) error
  48. QueryRowsCtx(ctx context.Context, v any, args ...any) error
  49. QueryRowsPartial(v any, args ...any) error
  50. QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
  51. }
  52. // thread-safe
  53. // Because CORBA doesn't support PREPARE, so we need to combine the
  54. // query arguments into one string and do underlying query without arguments
  55. commonSqlConn struct {
  56. connProv connProvider
  57. onError func(context.Context, error)
  58. beginTx beginnable
  59. brk breaker.Breaker
  60. accept func(error) bool
  61. }
  62. connProvider func() (*sql.DB, error)
  63. sessionConn interface {
  64. Exec(query string, args ...any) (sql.Result, error)
  65. ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
  66. Query(query string, args ...any) (*sql.Rows, error)
  67. QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  68. }
  69. statement struct {
  70. query string
  71. stmt *sql.Stmt
  72. }
  73. stmtConn interface {
  74. Exec(args ...any) (sql.Result, error)
  75. ExecContext(ctx context.Context, args ...any) (sql.Result, error)
  76. Query(args ...any) (*sql.Rows, error)
  77. QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
  78. }
  79. )
  80. // NewSqlConn returns a SqlConn with given driver name and datasource.
  81. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  82. conn := &commonSqlConn{
  83. connProv: func() (*sql.DB, error) {
  84. return getSqlConn(driverName, datasource)
  85. },
  86. onError: func(ctx context.Context, err error) {
  87. logInstanceError(ctx, datasource, err)
  88. },
  89. beginTx: begin,
  90. brk: breaker.NewBreaker(),
  91. }
  92. for _, opt := range opts {
  93. opt(conn)
  94. }
  95. return conn
  96. }
  97. // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
  98. // Use it with caution, it's provided for other ORM to interact with.
  99. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
  100. conn := &commonSqlConn{
  101. connProv: func() (*sql.DB, error) {
  102. return db, nil
  103. },
  104. onError: func(ctx context.Context, err error) {
  105. logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err)
  106. },
  107. beginTx: begin,
  108. brk: breaker.NewBreaker(),
  109. }
  110. for _, opt := range opts {
  111. opt(conn)
  112. }
  113. return conn
  114. }
  115. // NewSqlConnFromSession returns a SqlConn with the given session.
  116. func NewSqlConnFromSession(session Session) SqlConn {
  117. return txConn{
  118. Session: session,
  119. }
  120. }
  121. func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
  122. return db.ExecCtx(context.Background(), q, args...)
  123. }
  124. func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
  125. result sql.Result, err error) {
  126. ctx, span := startSpan(ctx, "Exec")
  127. defer func() {
  128. endSpan(span, err)
  129. }()
  130. err = db.brk.DoWithAcceptable(func() error {
  131. var conn *sql.DB
  132. conn, err = db.connProv()
  133. if err != nil {
  134. db.onError(ctx, err)
  135. return err
  136. }
  137. result, err = exec(ctx, conn, q, args...)
  138. return err
  139. }, db.acceptable)
  140. if errors.Is(err, breaker.ErrServiceUnavailable) {
  141. metricReqErr.Inc("Exec", "breaker")
  142. }
  143. return
  144. }
  145. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  146. return db.PrepareCtx(context.Background(), query)
  147. }
  148. func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
  149. ctx, span := startSpan(ctx, "Prepare")
  150. defer func() {
  151. endSpan(span, err)
  152. }()
  153. err = db.brk.DoWithAcceptable(func() error {
  154. var conn *sql.DB
  155. conn, err = db.connProv()
  156. if err != nil {
  157. db.onError(ctx, err)
  158. return err
  159. }
  160. st, err := conn.PrepareContext(ctx, query)
  161. if err != nil {
  162. return err
  163. }
  164. stmt = statement{
  165. query: query,
  166. stmt: st,
  167. }
  168. return nil
  169. }, db.acceptable)
  170. if errors.Is(err, breaker.ErrServiceUnavailable) {
  171. metricReqErr.Inc("Prepare", "breaker")
  172. }
  173. return
  174. }
  175. func (db *commonSqlConn) QueryRow(v any, q string, args ...any) error {
  176. return db.QueryRowCtx(context.Background(), v, q, args...)
  177. }
  178. func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, q string,
  179. args ...any) (err error) {
  180. ctx, span := startSpan(ctx, "QueryRow")
  181. defer func() {
  182. endSpan(span, err)
  183. }()
  184. return db.queryRows(ctx, func(rows *sql.Rows) error {
  185. return unmarshalRow(v, rows, true)
  186. }, q, args...)
  187. }
  188. func (db *commonSqlConn) QueryRowPartial(v any, q string, args ...any) error {
  189. return db.QueryRowPartialCtx(context.Background(), v, q, args...)
  190. }
  191. func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any,
  192. q string, args ...any) (err error) {
  193. ctx, span := startSpan(ctx, "QueryRowPartial")
  194. defer func() {
  195. endSpan(span, err)
  196. }()
  197. return db.queryRows(ctx, func(rows *sql.Rows) error {
  198. return unmarshalRow(v, rows, false)
  199. }, q, args...)
  200. }
  201. func (db *commonSqlConn) QueryRows(v any, q string, args ...any) error {
  202. return db.QueryRowsCtx(context.Background(), v, q, args...)
  203. }
  204. func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, q string,
  205. args ...any) (err error) {
  206. ctx, span := startSpan(ctx, "QueryRows")
  207. defer func() {
  208. endSpan(span, err)
  209. }()
  210. return db.queryRows(ctx, func(rows *sql.Rows) error {
  211. return unmarshalRows(v, rows, true)
  212. }, q, args...)
  213. }
  214. func (db *commonSqlConn) QueryRowsPartial(v any, q string, args ...any) error {
  215. return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
  216. }
  217. func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
  218. q string, args ...any) (err error) {
  219. ctx, span := startSpan(ctx, "QueryRowsPartial")
  220. defer func() {
  221. endSpan(span, err)
  222. }()
  223. return db.queryRows(ctx, func(rows *sql.Rows) error {
  224. return unmarshalRows(v, rows, false)
  225. }, q, args...)
  226. }
  227. func (db *commonSqlConn) RawDB() (*sql.DB, error) {
  228. return db.connProv()
  229. }
  230. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  231. return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
  232. return fn(session)
  233. })
  234. }
  235. func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
  236. ctx, span := startSpan(ctx, "Transact")
  237. defer func() {
  238. endSpan(span, err)
  239. }()
  240. err = db.brk.DoWithAcceptable(func() error {
  241. return transact(ctx, db, db.beginTx, fn)
  242. }, db.acceptable)
  243. if errors.Is(err, breaker.ErrServiceUnavailable) {
  244. metricReqErr.Inc("Transact", "breaker")
  245. }
  246. return
  247. }
  248. func (db *commonSqlConn) acceptable(err error) bool {
  249. if err == nil || errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) ||
  250. errors.Is(err, context.Canceled) {
  251. return true
  252. }
  253. var e acceptableError
  254. if errors.As(err, &e) {
  255. return true
  256. }
  257. if db.accept == nil {
  258. return false
  259. }
  260. return db.accept(err)
  261. }
  262. func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
  263. q string, args ...any) (err error) {
  264. var qerr error
  265. err = db.brk.DoWithAcceptable(func() error {
  266. conn, err := db.connProv()
  267. if err != nil {
  268. db.onError(ctx, err)
  269. return err
  270. }
  271. return query(ctx, conn, func(rows *sql.Rows) error {
  272. qerr = scanner(rows)
  273. return qerr
  274. }, q, args...)
  275. }, func(err error) bool {
  276. return errors.Is(err, qerr) || db.acceptable(err)
  277. })
  278. if errors.Is(err, breaker.ErrServiceUnavailable) {
  279. metricReqErr.Inc("queryRows", "breaker")
  280. }
  281. return
  282. }
  283. func (s statement) Close() error {
  284. return s.stmt.Close()
  285. }
  286. func (s statement) Exec(args ...any) (sql.Result, error) {
  287. return s.ExecCtx(context.Background(), args...)
  288. }
  289. func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
  290. ctx, span := startSpan(ctx, "Exec")
  291. defer func() {
  292. endSpan(span, err)
  293. }()
  294. return execStmt(ctx, s.stmt, s.query, args...)
  295. }
  296. func (s statement) QueryRow(v any, args ...any) error {
  297. return s.QueryRowCtx(context.Background(), v, args...)
  298. }
  299. func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
  300. ctx, span := startSpan(ctx, "QueryRow")
  301. defer func() {
  302. endSpan(span, err)
  303. }()
  304. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  305. return unmarshalRow(v, rows, true)
  306. }, s.query, args...)
  307. }
  308. func (s statement) QueryRowPartial(v any, args ...any) error {
  309. return s.QueryRowPartialCtx(context.Background(), v, args...)
  310. }
  311. func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
  312. ctx, span := startSpan(ctx, "QueryRowPartial")
  313. defer func() {
  314. endSpan(span, err)
  315. }()
  316. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  317. return unmarshalRow(v, rows, false)
  318. }, s.query, args...)
  319. }
  320. func (s statement) QueryRows(v any, args ...any) error {
  321. return s.QueryRowsCtx(context.Background(), v, args...)
  322. }
  323. func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
  324. ctx, span := startSpan(ctx, "QueryRows")
  325. defer func() {
  326. endSpan(span, err)
  327. }()
  328. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  329. return unmarshalRows(v, rows, true)
  330. }, s.query, args...)
  331. }
  332. func (s statement) QueryRowsPartial(v any, args ...any) error {
  333. return s.QueryRowsPartialCtx(context.Background(), v, args...)
  334. }
  335. func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
  336. ctx, span := startSpan(ctx, "QueryRowsPartial")
  337. defer func() {
  338. endSpan(span, err)
  339. }()
  340. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  341. return unmarshalRows(v, rows, false)
  342. }, s.query, args...)
  343. }
  344. // WithAcceptable returns a SqlOption that setting the acceptable function.
  345. // acceptable is the func to check if the error can be accepted.
  346. func WithAcceptable(acceptable func(err error) bool) SqlOption {
  347. return func(conn *commonSqlConn) {
  348. conn.accept = acceptable
  349. }
  350. }