sqlconn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. package sqlx
  import (
	"context"
	"database/sql"
	"errors"

	"github.com/zeromicro/go-zero/core/breaker"
	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/trace"
	"go.opentelemetry.io/otel"
	tracesdk "go.opentelemetry.io/otel/trace"
)
  // spanName is the name given to the tracing spans opened around SQL
// operations in this package.
const (
	spanName = "sql"
)
  // ErrNotFound is an alias of sql.ErrNoRows, so callers may compare a returned
// error against either value.
var (
	ErrNotFound = sql.ErrNoRows
)
  15. type (
  16. // Session stands for raw connections or transaction sessions
  17. Session interface {
  18. Exec(query string, args ...interface{}) (sql.Result, error)
  19. ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
  20. Prepare(query string) (StmtSession, error)
  21. PrepareCtx(ctx context.Context, query string) (StmtSession, error)
  22. QueryRow(v interface{}, query string, args ...interface{}) error
  23. QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  24. QueryRowPartial(v interface{}, query string, args ...interface{}) error
  25. QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  26. QueryRows(v interface{}, query string, args ...interface{}) error
  27. QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  28. QueryRowsPartial(v interface{}, query string, args ...interface{}) error
  29. QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
  30. }
  31. // SqlConn only stands for raw connections, so Transact method can be called.
  32. SqlConn interface {
  33. Session
  34. // RawDB is for other ORM to operate with, use it with caution.
  35. // Notice: don't close it.
  36. RawDB() (*sql.DB, error)
  37. Transact(fn func(Session) error) error
  38. TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
  39. }
  40. // SqlOption defines the method to customize a sql connection.
  41. SqlOption func(*commonSqlConn)
  42. // StmtSession interface represents a session that can be used to execute statements.
  43. StmtSession interface {
  44. Close() error
  45. Exec(args ...interface{}) (sql.Result, error)
  46. ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
  47. QueryRow(v interface{}, args ...interface{}) error
  48. QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
  49. QueryRowPartial(v interface{}, args ...interface{}) error
  50. QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
  51. QueryRows(v interface{}, args ...interface{}) error
  52. QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
  53. QueryRowsPartial(v interface{}, args ...interface{}) error
  54. QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
  55. }
  56. // thread-safe
  57. // Because CORBA doesn't support PREPARE, so we need to combine the
  58. // query arguments into one string and do underlying query without arguments
  59. commonSqlConn struct {
  60. connProv connProvider
  61. onError func(error)
  62. beginTx beginnable
  63. brk breaker.Breaker
  64. accept func(error) bool
  65. }
  66. connProvider func() (*sql.DB, error)
  67. sessionConn interface {
  68. Exec(query string, args ...interface{}) (sql.Result, error)
  69. ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
  70. Query(query string, args ...interface{}) (*sql.Rows, error)
  71. QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
  72. }
  73. statement struct {
  74. query string
  75. stmt *sql.Stmt
  76. }
  77. stmtConn interface {
  78. Exec(args ...interface{}) (sql.Result, error)
  79. ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
  80. Query(args ...interface{}) (*sql.Rows, error)
  81. QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
  82. }
  83. )
  // NewSqlConn returns a SqlConn for the given driver name and datasource.
// Each operation obtains its *sql.DB through getSqlConn. Failures to obtain it
// are reported through logInstanceError, and operations are guarded by a fresh
// circuit breaker. opts are applied last and may override any of these defaults.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
	provider := func() (*sql.DB, error) {
		return getSqlConn(driverName, datasource)
	}
	reportErr := func(err error) {
		logInstanceError(datasource, err)
	}

	conn := &commonSqlConn{
		connProv: provider,
		onError:  reportErr,
		beginTx:  begin,
		brk:      breaker.NewBreaker(),
	}
	for i := range opts {
		opts[i](conn)
	}

	return conn
}
  // NewSqlConnFromDB wraps an existing *sql.DB in a SqlConn.
// It exists so other ORMs can interoperate, so use it with caution.
// The default connection provider always returns db with a nil error.
// Provider errors, which can only arise if an option replaces the provider,
// are logged via logx. opts are applied last.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
	conn := new(commonSqlConn)
	conn.connProv = func() (*sql.DB, error) {
		return db, nil
	}
	conn.onError = func(err error) {
		logx.Errorf("Error on getting sql instance: %v", err)
	}
	conn.beginTx = begin
	conn.brk = breaker.NewBreaker()

	for _, apply := range opts {
		apply(conn)
	}

	return conn
}
  // Exec is ExecCtx with context.Background().
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
	ctx := context.Background()
	return db.ExecCtx(ctx, q, args...)
}

// ExecCtx executes q with args inside a tracing span and through the circuit
// breaker. If the connection provider fails, onError is invoked and that error
// is returned; otherwise the outcome of exec is returned.
func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
	result sql.Result, err error) {
	ctx, span := startSpan(ctx)
	defer span.End()

	run := func() error {
		conn, connErr := db.connProv()
		if connErr != nil {
			db.onError(connErr)
			return connErr
		}

		var execErr error
		result, execErr = exec(ctx, conn, q, args...)
		return execErr
	}
	err = db.brk.DoWithAcceptable(run, db.acceptable)

	return result, err
}
  138. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  139. return db.PrepareCtx(context.Background(), query)
  140. }
  141. func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
  142. ctx, span := startSpan(ctx)
  143. defer span.End()
  144. err = db.brk.DoWithAcceptable(func() error {
  145. var conn *sql.DB
  146. conn, err = db.connProv()
  147. if err != nil {
  148. db.onError(err)
  149. return err
  150. }
  151. st, err := conn.PrepareContext(ctx, query)
  152. if err != nil {
  153. return err
  154. }
  155. stmt = statement{
  156. query: query,
  157. stmt: st,
  158. }
  159. return nil
  160. }, db.acceptable)
  161. return
  162. }
  // QueryRow is QueryRowCtx with context.Background().
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
	ctx := context.Background()
	return db.QueryRowCtx(ctx, v, q, args...)
}

// QueryRowCtx runs q with args and unmarshals the resulting row into v.
// It calls unmarshalRow with its final flag set to true (the non-partial
// variant). The call is wrapped in a tracing span.
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
	args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, true)
	}
	return db.queryRows(spanCtx, scan, q, args...)
}
  // QueryRowPartial is QueryRowPartialCtx with context.Background().
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
	ctx := context.Background()
	return db.QueryRowPartialCtx(ctx, v, q, args...)
}

// QueryRowPartialCtx runs q with args and unmarshals the resulting row into v.
// It calls unmarshalRow with its final flag set to false (the partial
// variant). The call is wrapped in a tracing span.
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
	q string, args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, false)
	}
	return db.queryRows(spanCtx, scan, q, args...)
}
  // QueryRows is QueryRowsCtx with context.Background().
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
	ctx := context.Background()
	return db.QueryRowsCtx(ctx, v, q, args...)
}

// QueryRowsCtx runs q with args and unmarshals every resulting row into v.
// It calls unmarshalRows with its final flag set to true (the non-partial
// variant). The call is wrapped in a tracing span.
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
	args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, true)
	}
	return db.queryRows(spanCtx, scan, q, args...)
}
  // QueryRowsPartial is QueryRowsPartialCtx with context.Background().
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
	ctx := context.Background()
	return db.QueryRowsPartialCtx(ctx, v, q, args...)
}

// QueryRowsPartialCtx runs q with args and unmarshals every resulting row into v.
// It calls unmarshalRows with its final flag set to false (the partial
// variant). The call is wrapped in a tracing span.
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
	q string, args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, false)
	}
	return db.queryRows(spanCtx, scan, q, args...)
}
  // RawDB returns the *sql.DB produced by the connection provider so other
// ORMs can use it. Callers must not close it. Unlike the query methods, a
// provider error is returned directly without invoking onError.
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
	raw, err := db.connProv()
	return raw, err
}
  // Transact is TransactCtx with context.Background(). fn receives only the
// Session; the context is dropped.
func (db *commonSqlConn) Transact(fn func(Session) error) error {
	adapted := func(_ context.Context, session Session) error {
		return fn(session)
	}
	return db.TransactCtx(context.Background(), adapted)
}

// TransactCtx runs fn via the transact helper, using db.beginTx to start the
// transaction. The work is wrapped in a tracing span and goes through the
// circuit breaker. Commit/rollback handling lives in transact (presumably
// commit on success and rollback on error — verify there).
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	run := func() error {
		return transact(spanCtx, db, db.beginTx, fn)
	}
	return db.brk.DoWithAcceptable(run, db.acceptable)
}
  // acceptable reports whether err should count as a success for the circuit
// breaker.
//
// A nil error, sql.ErrNoRows, sql.ErrTxDone and context.Canceled are always
// acceptable: they reflect the caller's query or lifecycle, not database
// health. They are matched with errors.Is so that wrapped forms (for example
// fmt.Errorf("...: %w", context.Canceled)) do not trip the breaker. Any other
// error is acceptable only if the optional db.accept predicate says so.
func (db *commonSqlConn) acceptable(err error) bool {
	if err == nil ||
		errors.Is(err, sql.ErrNoRows) ||
		errors.Is(err, sql.ErrTxDone) ||
		errors.Is(err, context.Canceled) {
		return true
	}

	return db.accept != nil && db.accept(err)
}
  // queryRows obtains a *sql.DB from the connection provider, runs q with args
// through the circuit breaker, and hands the resulting rows to scanner.
//
// No tracing span is started here. Every caller in this file (QueryRowCtx,
// QueryRowPartialCtx, QueryRowsCtx, QueryRowsPartialCtx) already opens one,
// so an extra span here only adds a redundant nested "sql" span per query.
//
// A provider failure is reported through onError and returned. If query
// returns exactly the error produced by scanner, the breaker treats it as
// acceptable. Such scan/unmarshal errors presumably reflect the result shape,
// not database health. The error is still returned to the caller.
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
	q string, args ...interface{}) error {
	var scanErr error
	return db.brk.DoWithAcceptable(func() error {
		conn, err := db.connProv()
		if err != nil {
			db.onError(err)
			return err
		}

		return query(ctx, conn, func(rows *sql.Rows) error {
			scanErr = scanner(rows)
			return scanErr
		}, q, args...)
	}, func(err error) bool {
		return err == scanErr || db.acceptable(err)
	})
}
  248. func (s statement) Close() error {
  249. return s.stmt.Close()
  250. }
  // Exec is ExecCtx with context.Background().
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
	ctx := context.Background()
	return s.ExecCtx(ctx, args...)
}

// ExecCtx executes the prepared statement with args inside a tracing span,
// delegating to execStmt along with the original query text. Unlike
// commonSqlConn.ExecCtx, this path does not go through the circuit breaker.
func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	return execStmt(spanCtx, s.stmt, s.query, args...)
}
  // QueryRow is QueryRowCtx with context.Background().
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
	ctx := context.Background()
	return s.QueryRowCtx(ctx, v, args...)
}

// QueryRowCtx runs the prepared statement with args and unmarshals the row
// into v. It calls unmarshalRow with its final flag set to true (the
// non-partial variant). The call is wrapped in a tracing span.
func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, true)
	}
	return queryStmt(spanCtx, s.stmt, scan, s.query, args...)
}
  // QueryRowPartial is QueryRowPartialCtx with context.Background().
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
	ctx := context.Background()
	return s.QueryRowPartialCtx(ctx, v, args...)
}

// QueryRowPartialCtx runs the prepared statement with args and unmarshals the
// row into v. It calls unmarshalRow with its final flag set to false (the
// partial variant). The call is wrapped in a tracing span.
func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, false)
	}
	return queryStmt(spanCtx, s.stmt, scan, s.query, args...)
}
  // QueryRows is QueryRowsCtx with context.Background().
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
	ctx := context.Background()
	return s.QueryRowsCtx(ctx, v, args...)
}

// QueryRowsCtx runs the prepared statement with args and unmarshals every row
// into v. It calls unmarshalRows with its final flag set to true (the
// non-partial variant). The call is wrapped in a tracing span.
func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, true)
	}
	return queryStmt(spanCtx, s.stmt, scan, s.query, args...)
}
  // QueryRowsPartial is QueryRowsPartialCtx with context.Background().
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
	ctx := context.Background()
	return s.QueryRowsPartialCtx(ctx, v, args...)
}

// QueryRowsPartialCtx runs the prepared statement with args and unmarshals
// every row into v. It calls unmarshalRows with its final flag set to false
// (the partial variant). The call is wrapped in a tracing span.
func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
	spanCtx, span := startSpan(ctx)
	defer span.End()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, false)
	}
	return queryStmt(spanCtx, s.stmt, scan, s.query, args...)
}
  // startSpan starts a span named spanName on the tracer that the global
// OpenTelemetry tracer provider returns for trace.TraceName. It returns the
// derived context together with the span; callers must End the span.
func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
	return otel.GetTracerProvider().Tracer(trace.TraceName).Start(ctx, spanName)
}