sqlconn.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  // Adapted from core/stores/sqlx/sqlconn.go; context arguments are accepted but ignored.
  2. package mocksql
  3. import (
  4. "context"
  5. "database/sql"
  6. "github.com/zeromicro/go-zero/core/stores/sqlx"
  7. )
  // MockConn defines a mock connection instance for mysql.
// Every call made through it is forwarded to the wrapped *sql.DB.
type MockConn struct {
	db *sql.DB // handle supplied to NewMockConn
}

// statement wraps a prepared *sql.Stmt; MockConn.Prepare hands it out as an
// sqlx.StmtSession.
type statement struct {
	stmt *sql.Stmt
}
  17. // NewMockConn creates an instance for MockConn
  18. func NewMockConn(db *sql.DB) *MockConn {
  19. return &MockConn{db: db}
  20. }
  21. // Exec executes sql and returns the result
  22. func (conn *MockConn) Exec(query string, args ...any) (sql.Result, error) {
  23. return exec(conn.db, query, args...)
  24. }
  25. // ExecCtx executes sql and returns the result
  26. func (conn *MockConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
  27. return exec(conn.db, query, args...)
  28. }
  29. // Prepare executes sql by sql.DB
  30. func (conn *MockConn) Prepare(query string) (sqlx.StmtSession, error) {
  31. st, err := conn.db.Prepare(query)
  32. return statement{stmt: st}, err
  33. }
  34. // PrepareCtx executes sql by sql.DB
  35. func (conn *MockConn) PrepareCtx(_ context.Context, query string) (sqlx.StmtSession, error) {
  36. return conn.Prepare(query)
  37. }
  38. // QueryRow executes sql and returns a query row
  39. func (conn *MockConn) QueryRow(v any, q string, args ...any) error {
  40. return query(conn.db, func(rows *sql.Rows) error {
  41. return unmarshalRow(v, rows, true)
  42. }, q, args...)
  43. }
  44. // QueryRowCtx executes sql and returns a query row
  45. func (conn *MockConn) QueryRowCtx(_ context.Context, v any, query string, args ...any) error {
  46. return conn.QueryRow(v, query, args...)
  47. }
  48. // QueryRowPartial executes sql and returns a partial query row
  49. func (conn *MockConn) QueryRowPartial(v any, q string, args ...any) error {
  50. return query(conn.db, func(rows *sql.Rows) error {
  51. return unmarshalRow(v, rows, false)
  52. }, q, args...)
  53. }
  54. // QueryRowPartialCtx executes sql and returns a partial query row
  55. func (conn *MockConn) QueryRowPartialCtx(_ context.Context, v any, query string, args ...any) error {
  56. return conn.QueryRowPartial(v, query, args...)
  57. }
  58. // QueryRows executes sql and returns query rows
  59. func (conn *MockConn) QueryRows(v any, q string, args ...any) error {
  60. return query(conn.db, func(rows *sql.Rows) error {
  61. return unmarshalRows(v, rows, true)
  62. }, q, args...)
  63. }
  64. // QueryRowsCtx executes sql and returns query rows
  65. func (conn *MockConn) QueryRowsCtx(_ context.Context, v any, query string, args ...any) error {
  66. return conn.QueryRows(v, query, args...)
  67. }
  68. // QueryRowsPartial executes sql and returns partial query rows
  69. func (conn *MockConn) QueryRowsPartial(v any, q string, args ...any) error {
  70. return query(conn.db, func(rows *sql.Rows) error {
  71. return unmarshalRows(v, rows, false)
  72. }, q, args...)
  73. }
  74. // QueryRowsPartialCtx executes sql and returns partial query rows
  75. func (conn *MockConn) QueryRowsPartialCtx(_ context.Context, v any, query string, args ...any) error {
  76. return conn.QueryRowsPartial(v, query, args...)
  77. }
  78. // RawDB returns the underlying sql.DB.
  79. func (conn *MockConn) RawDB() (*sql.DB, error) {
  80. return conn.db, nil
  81. }
  82. // Transact is the implemention of sqlx.SqlConn, nothing to do
  83. func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
  84. return nil
  85. }
  86. // TransactCtx is the implemention of sqlx.SqlConn, nothing to do
  87. func (conn *MockConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
  88. return nil
  89. }
  90. func (s statement) Close() error {
  91. return s.stmt.Close()
  92. }
  93. func (s statement) Exec(args ...any) (sql.Result, error) {
  94. return execStmt(s.stmt, args...)
  95. }
  96. func (s statement) ExecCtx(_ context.Context, args ...any) (sql.Result, error) {
  97. return s.Exec(args...)
  98. }
  99. func (s statement) QueryRow(v any, args ...any) error {
  100. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  101. return unmarshalRow(v, rows, true)
  102. }, args...)
  103. }
  104. func (s statement) QueryRowCtx(_ context.Context, v any, args ...any) error {
  105. return s.QueryRow(v, args...)
  106. }
  107. func (s statement) QueryRowPartial(v any, args ...any) error {
  108. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  109. return unmarshalRow(v, rows, false)
  110. }, args...)
  111. }
  112. func (s statement) QueryRowPartialCtx(_ context.Context, v any, args ...any) error {
  113. return s.QueryRowPartial(v, args...)
  114. }
  115. func (s statement) QueryRows(v any, args ...any) error {
  116. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  117. return unmarshalRows(v, rows, true)
  118. }, args...)
  119. }
  120. func (s statement) QueryRowsCtx(_ context.Context, v any, args ...any) error {
  121. return s.QueryRows(v, args...)
  122. }
  123. func (s statement) QueryRowsPartial(v any, args ...any) error {
  124. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  125. return unmarshalRows(v, rows, false)
  126. }, args...)
  127. }
  128. func (s statement) QueryRowsPartialCtx(_ context.Context, v any, args ...any) error {
  129. return s.QueryRowsPartial(v, args...)
  130. }