123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- package sqlx
import (
	"context"
	"database/sql"
	"errors"

	"github.com/zeromicro/go-zero/core/breaker"
	"github.com/zeromicro/go-zero/core/logx"
)
// spanName identifies the span name used for SQL executions.
const spanName = "sql"

// ErrNotFound is the same value as sql.ErrNoRows, so comparisons against
// either one match.
var ErrNotFound error = sql.ErrNoRows
- type (
- // Session stands for raw connections or transaction sessions
- Session interface {
- Exec(query string, args ...any) (sql.Result, error)
- ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
- Prepare(query string) (StmtSession, error)
- PrepareCtx(ctx context.Context, query string) (StmtSession, error)
- QueryRow(v any, query string, args ...any) error
- QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
- QueryRowPartial(v any, query string, args ...any) error
- QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
- QueryRows(v any, query string, args ...any) error
- QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
- QueryRowsPartial(v any, query string, args ...any) error
- QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
- }
- // SqlConn only stands for raw connections, so Transact method can be called.
- SqlConn interface {
- Session
- // RawDB is for other ORM to operate with, use it with caution.
- // Notice: don't close it.
- RawDB() (*sql.DB, error)
- Transact(fn func(Session) error) error
- TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
- }
- // SqlOption defines the method to customize a sql connection.
- SqlOption func(*commonSqlConn)
- // StmtSession interface represents a session that can be used to execute statements.
- StmtSession interface {
- Close() error
- Exec(args ...any) (sql.Result, error)
- ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
- QueryRow(v any, args ...any) error
- QueryRowCtx(ctx context.Context, v any, args ...any) error
- QueryRowPartial(v any, args ...any) error
- QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
- QueryRows(v any, args ...any) error
- QueryRowsCtx(ctx context.Context, v any, args ...any) error
- QueryRowsPartial(v any, args ...any) error
- QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
- }
- // thread-safe
- // Because CORBA doesn't support PREPARE, so we need to combine the
- // query arguments into one string and do underlying query without arguments
- commonSqlConn struct {
- connProv connProvider
- onError func(error)
- beginTx beginnable
- brk breaker.Breaker
- accept func(error) bool
- }
- connProvider func() (*sql.DB, error)
- sessionConn interface {
- Exec(query string, args ...any) (sql.Result, error)
- ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
- Query(query string, args ...any) (*sql.Rows, error)
- QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
- }
- statement struct {
- query string
- stmt *sql.Stmt
- }
- stmtConn interface {
- Exec(args ...any) (sql.Result, error)
- ExecContext(ctx context.Context, args ...any) (sql.Result, error)
- Query(args ...any) (*sql.Rows, error)
- QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
- }
- )
- // NewSqlConn returns a SqlConn with given driver name and datasource.
- func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
- conn := &commonSqlConn{
- connProv: func() (*sql.DB, error) {
- return getSqlConn(driverName, datasource)
- },
- onError: func(err error) {
- logInstanceError(datasource, err)
- },
- beginTx: begin,
- brk: breaker.NewBreaker(),
- }
- for _, opt := range opts {
- opt(conn)
- }
- return conn
- }
- // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
- // Use it with caution, it's provided for other ORM to interact with.
- func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
- conn := &commonSqlConn{
- connProv: func() (*sql.DB, error) {
- return db, nil
- },
- onError: func(err error) {
- logx.Errorf("Error on getting sql instance: %v", err)
- },
- beginTx: begin,
- brk: breaker.NewBreaker(),
- }
- for _, opt := range opts {
- opt(conn)
- }
- return conn
- }
- func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
- return db.ExecCtx(context.Background(), q, args...)
- }
- func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
- result sql.Result, err error) {
- ctx, span := startSpan(ctx, "Exec")
- defer func() {
- endSpan(span, err)
- }()
- err = db.brk.DoWithAcceptable(func() error {
- var conn *sql.DB
- conn, err = db.connProv()
- if err != nil {
- db.onError(err)
- return err
- }
- result, err = exec(ctx, conn, q, args...)
- return err
- }, db.acceptable)
- if err == breaker.ErrServiceUnavailable {
- metricReqErr.Inc("Exec", "breaker")
- }
- return
- }
- func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
- return db.PrepareCtx(context.Background(), query)
- }
- func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
- ctx, span := startSpan(ctx, "Prepare")
- defer func() {
- endSpan(span, err)
- }()
- err = db.brk.DoWithAcceptable(func() error {
- var conn *sql.DB
- conn, err = db.connProv()
- if err != nil {
- db.onError(err)
- return err
- }
- st, err := conn.PrepareContext(ctx, query)
- if err != nil {
- return err
- }
- stmt = statement{
- query: query,
- stmt: st,
- }
- return nil
- }, db.acceptable)
- if err == breaker.ErrServiceUnavailable {
- metricReqErr.Inc("Prepare", "breaker")
- }
- return
- }
- func (db *commonSqlConn) QueryRow(v any, q string, args ...any) error {
- return db.QueryRowCtx(context.Background(), v, q, args...)
- }
- func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, q string,
- args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRow")
- defer func() {
- endSpan(span, err)
- }()
- return db.queryRows(ctx, func(rows *sql.Rows) error {
- return unmarshalRow(v, rows, true)
- }, q, args...)
- }
- func (db *commonSqlConn) QueryRowPartial(v any, q string, args ...any) error {
- return db.QueryRowPartialCtx(context.Background(), v, q, args...)
- }
- func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any,
- q string, args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRowPartial")
- defer func() {
- endSpan(span, err)
- }()
- return db.queryRows(ctx, func(rows *sql.Rows) error {
- return unmarshalRow(v, rows, false)
- }, q, args...)
- }
- func (db *commonSqlConn) QueryRows(v any, q string, args ...any) error {
- return db.QueryRowsCtx(context.Background(), v, q, args...)
- }
- func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, q string,
- args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRows")
- defer func() {
- endSpan(span, err)
- }()
- return db.queryRows(ctx, func(rows *sql.Rows) error {
- return unmarshalRows(v, rows, true)
- }, q, args...)
- }
- func (db *commonSqlConn) QueryRowsPartial(v any, q string, args ...any) error {
- return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
- }
- func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
- q string, args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRowsPartial")
- defer func() {
- endSpan(span, err)
- }()
- return db.queryRows(ctx, func(rows *sql.Rows) error {
- return unmarshalRows(v, rows, false)
- }, q, args...)
- }
- func (db *commonSqlConn) RawDB() (*sql.DB, error) {
- return db.connProv()
- }
- func (db *commonSqlConn) Transact(fn func(Session) error) error {
- return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
- return fn(session)
- })
- }
- func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
- ctx, span := startSpan(ctx, "Transact")
- defer func() {
- endSpan(span, err)
- }()
- err = db.brk.DoWithAcceptable(func() error {
- return transact(ctx, db, db.beginTx, fn)
- }, db.acceptable)
- if err == breaker.ErrServiceUnavailable {
- metricReqErr.Inc("Transact", "breaker")
- }
- return
- }
- func (db *commonSqlConn) acceptable(err error) bool {
- ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
- if db.accept == nil {
- return ok
- }
- return ok || db.accept(err)
- }
- func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
- q string, args ...any) (err error) {
- var qerr error
- err = db.brk.DoWithAcceptable(func() error {
- conn, err := db.connProv()
- if err != nil {
- db.onError(err)
- return err
- }
- return query(ctx, conn, func(rows *sql.Rows) error {
- qerr = scanner(rows)
- return qerr
- }, q, args...)
- }, func(err error) bool {
- return qerr == err || db.acceptable(err)
- })
- if err == breaker.ErrServiceUnavailable {
- metricReqErr.Inc("queryRows", "breaker")
- }
- return
- }
- func (s statement) Close() error {
- return s.stmt.Close()
- }
- func (s statement) Exec(args ...any) (sql.Result, error) {
- return s.ExecCtx(context.Background(), args...)
- }
- func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
- ctx, span := startSpan(ctx, "Exec")
- defer func() {
- endSpan(span, err)
- }()
- return execStmt(ctx, s.stmt, s.query, args...)
- }
- func (s statement) QueryRow(v any, args ...any) error {
- return s.QueryRowCtx(context.Background(), v, args...)
- }
- func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRow")
- defer func() {
- endSpan(span, err)
- }()
- return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
- return unmarshalRow(v, rows, true)
- }, s.query, args...)
- }
- func (s statement) QueryRowPartial(v any, args ...any) error {
- return s.QueryRowPartialCtx(context.Background(), v, args...)
- }
- func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRowPartial")
- defer func() {
- endSpan(span, err)
- }()
- return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
- return unmarshalRow(v, rows, false)
- }, s.query, args...)
- }
- func (s statement) QueryRows(v any, args ...any) error {
- return s.QueryRowsCtx(context.Background(), v, args...)
- }
- func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRows")
- defer func() {
- endSpan(span, err)
- }()
- return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
- return unmarshalRows(v, rows, true)
- }, s.query, args...)
- }
- func (s statement) QueryRowsPartial(v any, args ...any) error {
- return s.QueryRowsPartialCtx(context.Background(), v, args...)
- }
- func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
- ctx, span := startSpan(ctx, "QueryRowsPartial")
- defer func() {
- endSpan(span, err)
- }()
- return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
- return unmarshalRows(v, rows, false)
- }, s.query, args...)
- }
|