sqlconn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. package sqlx
  import (
  	"context"
  	"database/sql"
  	"errors"

  	"github.com/zeromicro/go-zero/core/breaker"
  	"github.com/zeromicro/go-zero/core/logx"
  	"github.com/zeromicro/go-zero/core/trace"
  	"go.opentelemetry.io/otel"
  	tracesdk "go.opentelemetry.io/otel/trace"
  )
  // ErrNotFound is reported when a query yields no rows. It is the very same
  // value as sql.ErrNoRows, so the two compare equal.
  var ErrNotFound = sql.ErrNoRows
  13. type (
  14. // Session stands for raw connections or transaction sessions
  15. Session interface {
  16. Exec(query string, args ...interface{}) (sql.Result, error)
  17. ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
  18. Prepare(query string) (StmtSession, error)
  19. PrepareCtx(ctx context.Context, query string) (StmtSession, error)
  20. QueryRow(v interface{}, query string, args ...interface{}) error
  21. QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  22. QueryRowPartial(v interface{}, query string, args ...interface{}) error
  23. QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  24. QueryRows(v interface{}, query string, args ...interface{}) error
  25. QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  26. QueryRowsPartial(v interface{}, query string, args ...interface{}) error
  27. QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  28. }
  29. // SqlConn only stands for raw connections, so Transact method can be called.
  30. SqlConn interface {
  31. Session
  32. // RawDB is for other ORM to operate with, use it with caution.
  33. // Notice: don't close it.
  34. RawDB() (*sql.DB, error)
  35. Transact(fn func(Session) error) error
  36. TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
  37. }
  38. // SqlOption defines the method to customize a sql connection.
  39. SqlOption func(*commonSqlConn)
  40. // StmtSession interface represents a session that can be used to execute statements.
  41. StmtSession interface {
  42. Close() error
  43. Exec(args ...interface{}) (sql.Result, error)
  44. ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
  45. QueryRow(v interface{}, args ...interface{}) error
  46. QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
  47. QueryRowPartial(v interface{}, args ...interface{}) error
  48. QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
  49. QueryRows(v interface{}, args ...interface{}) error
  50. QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
  51. QueryRowsPartial(v interface{}, args ...interface{}) error
  52. QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
  53. }
  54. // thread-safe
  55. // Because CORBA doesn't support PREPARE, so we need to combine the
  56. // query arguments into one string and do underlying query without arguments
  57. commonSqlConn struct {
  58. connProv connProvider
  59. onError func(error)
  60. beginTx beginnable
  61. brk breaker.Breaker
  62. accept func(error) bool
  63. }
  64. connProvider func() (*sql.DB, error)
  65. sessionConn interface {
  66. Exec(query string, args ...interface{}) (sql.Result, error)
  67. ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
  68. Query(query string, args ...interface{}) (*sql.Rows, error)
  69. QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
  70. }
  71. statement struct {
  72. query string
  73. stmt *sql.Stmt
  74. }
  75. stmtConn interface {
  76. Exec(args ...interface{}) (sql.Result, error)
  77. ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
  78. Query(args ...interface{}) (*sql.Rows, error)
  79. QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
  80. }
  81. )
  82. // NewSqlConn returns a SqlConn with given driver name and datasource.
  83. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  84. conn := &commonSqlConn{
  85. connProv: func() (*sql.DB, error) {
  86. return getSqlConn(driverName, datasource)
  87. },
  88. onError: func(err error) {
  89. logInstanceError(datasource, err)
  90. },
  91. beginTx: begin,
  92. brk: breaker.NewBreaker(),
  93. }
  94. for _, opt := range opts {
  95. opt(conn)
  96. }
  97. return conn
  98. }
  99. // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
  100. // Use it with caution, it's provided for other ORM to interact with.
  101. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
  102. conn := &commonSqlConn{
  103. connProv: func() (*sql.DB, error) {
  104. return db, nil
  105. },
  106. onError: func(err error) {
  107. logx.Errorf("Error on getting sql instance: %v", err)
  108. },
  109. beginTx: begin,
  110. brk: breaker.NewBreaker(),
  111. }
  112. for _, opt := range opts {
  113. opt(conn)
  114. }
  115. return conn
  116. }
  117. func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
  118. return db.ExecCtx(context.Background(), q, args...)
  119. }
  120. func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
  121. result sql.Result, err error) {
  122. ctx, span := startSpan(ctx)
  123. defer span.End()
  124. err = db.brk.DoWithAcceptable(func() error {
  125. var conn *sql.DB
  126. conn, err = db.connProv()
  127. if err != nil {
  128. db.onError(err)
  129. return err
  130. }
  131. result, err = exec(ctx, conn, q, args...)
  132. return err
  133. }, db.acceptable)
  134. return
  135. }
  136. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  137. return db.PrepareCtx(context.Background(), query)
  138. }
  139. func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
  140. ctx, span := startSpan(ctx)
  141. defer span.End()
  142. err = db.brk.DoWithAcceptable(func() error {
  143. var conn *sql.DB
  144. conn, err = db.connProv()
  145. if err != nil {
  146. db.onError(err)
  147. return err
  148. }
  149. st, err := conn.PrepareContext(ctx, query)
  150. if err != nil {
  151. return err
  152. }
  153. stmt = statement{
  154. query: query,
  155. stmt: st,
  156. }
  157. return nil
  158. }, db.acceptable)
  159. return
  160. }
  161. func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
  162. return db.QueryRowCtx(context.Background(), v, q, args...)
  163. }
  164. func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
  165. args ...interface{}) error {
  166. ctx, span := startSpan(ctx)
  167. defer span.End()
  168. return db.queryRows(ctx, func(rows *sql.Rows) error {
  169. return unmarshalRow(v, rows, true)
  170. }, q, args...)
  171. }
  172. func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  173. return db.QueryRowPartialCtx(context.Background(), v, q, args...)
  174. }
  175. func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
  176. q string, args ...interface{}) error {
  177. ctx, span := startSpan(ctx)
  178. defer span.End()
  179. return db.queryRows(ctx, func(rows *sql.Rows) error {
  180. return unmarshalRow(v, rows, false)
  181. }, q, args...)
  182. }
  183. func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
  184. return db.QueryRowsCtx(context.Background(), v, q, args...)
  185. }
  186. func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
  187. args ...interface{}) error {
  188. ctx, span := startSpan(ctx)
  189. defer span.End()
  190. return db.queryRows(ctx, func(rows *sql.Rows) error {
  191. return unmarshalRows(v, rows, true)
  192. }, q, args...)
  193. }
  194. func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  195. return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
  196. }
  197. func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
  198. q string, args ...interface{}) error {
  199. ctx, span := startSpan(ctx)
  200. defer span.End()
  201. return db.queryRows(ctx, func(rows *sql.Rows) error {
  202. return unmarshalRows(v, rows, false)
  203. }, q, args...)
  204. }
  205. func (db *commonSqlConn) RawDB() (*sql.DB, error) {
  206. return db.connProv()
  207. }
  208. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  209. return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
  210. return fn(session)
  211. })
  212. }
  213. func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
  214. ctx, span := startSpan(ctx)
  215. defer span.End()
  216. return db.brk.DoWithAcceptable(func() error {
  217. return transact(ctx, db, db.beginTx, fn)
  218. }, db.acceptable)
  219. }
  220. func (db *commonSqlConn) acceptable(err error) bool {
  221. ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
  222. if db.accept == nil {
  223. return ok
  224. }
  225. return ok || db.accept(err)
  226. }
  227. func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
  228. q string, args ...interface{}) error {
  229. ctx, span := startSpan(ctx)
  230. defer span.End()
  231. var qerr error
  232. return db.brk.DoWithAcceptable(func() error {
  233. conn, err := db.connProv()
  234. if err != nil {
  235. db.onError(err)
  236. return err
  237. }
  238. return query(ctx, conn, func(rows *sql.Rows) error {
  239. qerr = scanner(rows)
  240. return qerr
  241. }, q, args...)
  242. }, func(err error) bool {
  243. return qerr == err || db.acceptable(err)
  244. })
  245. }
  246. func (s statement) Close() error {
  247. return s.stmt.Close()
  248. }
  249. func (s statement) Exec(args ...interface{}) (sql.Result, error) {
  250. return s.ExecCtx(context.Background(), args...)
  251. }
  252. func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) {
  253. ctx, span := startSpan(ctx)
  254. defer span.End()
  255. return execStmt(ctx, s.stmt, s.query, args...)
  256. }
  257. func (s statement) QueryRow(v interface{}, args ...interface{}) error {
  258. return s.QueryRowCtx(context.Background(), v, args...)
  259. }
  260. func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  261. ctx, span := startSpan(ctx)
  262. defer span.End()
  263. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  264. return unmarshalRow(v, rows, true)
  265. }, s.query, args...)
  266. }
  267. func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
  268. return s.QueryRowPartialCtx(context.Background(), v, args...)
  269. }
  270. func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  271. ctx, span := startSpan(ctx)
  272. defer span.End()
  273. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  274. return unmarshalRow(v, rows, false)
  275. }, s.query, args...)
  276. }
  277. func (s statement) QueryRows(v interface{}, args ...interface{}) error {
  278. return s.QueryRowsCtx(context.Background(), v, args...)
  279. }
  280. func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  281. ctx, span := startSpan(ctx)
  282. defer span.End()
  283. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  284. return unmarshalRows(v, rows, true)
  285. }, s.query, args...)
  286. }
  287. func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
  288. return s.QueryRowsPartialCtx(context.Background(), v, args...)
  289. }
  290. func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
  291. ctx, span := startSpan(ctx)
  292. defer span.End()
  293. return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
  294. return unmarshalRows(v, rows, false)
  295. }, s.query, args...)
  296. }
  297. func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
  298. tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
  299. return tracer.Start(ctx, "sql")
  300. }