stmt_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. package sqlx
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. var errMockedPlaceholder = errors.New("placeholder")
  11. func TestStmt_exec(t *testing.T) {
  12. tests := []struct {
  13. name string
  14. query string
  15. args []interface{}
  16. delay bool
  17. hasError bool
  18. err error
  19. lastInsertId int64
  20. rowsAffected int64
  21. }{
  22. {
  23. name: "normal",
  24. query: "select user from users where id=?",
  25. args: []interface{}{1},
  26. lastInsertId: 1,
  27. rowsAffected: 2,
  28. },
  29. {
  30. name: "exec error",
  31. query: "select user from users where id=?",
  32. args: []interface{}{1},
  33. hasError: true,
  34. err: errors.New("exec"),
  35. },
  36. {
  37. name: "exec more args error",
  38. query: "select user from users where id=? and name=?",
  39. args: []interface{}{1},
  40. hasError: true,
  41. err: errors.New("exec"),
  42. },
  43. {
  44. name: "slowcall",
  45. query: "select user from users where id=?",
  46. args: []interface{}{1},
  47. delay: true,
  48. lastInsertId: 1,
  49. rowsAffected: 2,
  50. },
  51. }
  52. for _, test := range tests {
  53. test := test
  54. fns := []func(args ...interface{}) (sql.Result, error){
  55. func(args ...interface{}) (sql.Result, error) {
  56. return exec(context.Background(), &mockedSessionConn{
  57. lastInsertId: test.lastInsertId,
  58. rowsAffected: test.rowsAffected,
  59. err: test.err,
  60. delay: test.delay,
  61. }, test.query, args...)
  62. },
  63. func(args ...interface{}) (sql.Result, error) {
  64. return execStmt(context.Background(), &mockedStmtConn{
  65. lastInsertId: test.lastInsertId,
  66. rowsAffected: test.rowsAffected,
  67. err: test.err,
  68. delay: test.delay,
  69. }, test.query, args...)
  70. },
  71. }
  72. for _, fn := range fns {
  73. fn := fn
  74. t.Run(test.name, func(t *testing.T) {
  75. t.Parallel()
  76. res, err := fn(test.args...)
  77. if test.hasError {
  78. assert.NotNil(t, err)
  79. return
  80. }
  81. assert.Nil(t, err)
  82. lastInsertId, err := res.LastInsertId()
  83. assert.Nil(t, err)
  84. assert.Equal(t, test.lastInsertId, lastInsertId)
  85. rowsAffected, err := res.RowsAffected()
  86. assert.Nil(t, err)
  87. assert.Equal(t, test.rowsAffected, rowsAffected)
  88. })
  89. }
  90. }
  91. }
  92. func TestStmt_query(t *testing.T) {
  93. tests := []struct {
  94. name string
  95. query string
  96. args []interface{}
  97. delay bool
  98. hasError bool
  99. err error
  100. }{
  101. {
  102. name: "normal",
  103. query: "select user from users where id=?",
  104. args: []interface{}{1},
  105. },
  106. {
  107. name: "query error",
  108. query: "select user from users where id=?",
  109. args: []interface{}{1},
  110. hasError: true,
  111. err: errors.New("exec"),
  112. },
  113. {
  114. name: "query more args error",
  115. query: "select user from users where id=? and name=?",
  116. args: []interface{}{1},
  117. hasError: true,
  118. err: errors.New("exec"),
  119. },
  120. {
  121. name: "slowcall",
  122. query: "select user from users where id=?",
  123. args: []interface{}{1},
  124. delay: true,
  125. },
  126. }
  127. for _, test := range tests {
  128. test := test
  129. fns := []func(args ...interface{}) error{
  130. func(args ...interface{}) error {
  131. return query(context.Background(), &mockedSessionConn{
  132. err: test.err,
  133. delay: test.delay,
  134. }, func(rows *sql.Rows) error {
  135. return nil
  136. }, test.query, args...)
  137. },
  138. func(args ...interface{}) error {
  139. return queryStmt(context.Background(), &mockedStmtConn{
  140. err: test.err,
  141. delay: test.delay,
  142. }, func(rows *sql.Rows) error {
  143. return nil
  144. }, test.query, args...)
  145. },
  146. }
  147. for _, fn := range fns {
  148. fn := fn
  149. t.Run(test.name, func(t *testing.T) {
  150. t.Parallel()
  151. err := fn(test.args...)
  152. if test.hasError {
  153. assert.NotNil(t, err)
  154. return
  155. }
  156. assert.NotNil(t, err)
  157. })
  158. }
  159. }
  160. }
  161. func TestSetSlowThreshold(t *testing.T) {
  162. assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
  163. SetSlowThreshold(time.Second)
  164. assert.Equal(t, time.Second, slowThreshold.Load())
  165. }
  166. type mockedSessionConn struct {
  167. lastInsertId int64
  168. rowsAffected int64
  169. err error
  170. delay bool
  171. }
  172. func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
  173. return m.ExecContext(context.Background(), query, args...)
  174. }
  175. func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
  176. if m.delay {
  177. time.Sleep(defaultSlowThreshold + time.Millisecond)
  178. }
  179. return mockedResult{
  180. lastInsertId: m.lastInsertId,
  181. rowsAffected: m.rowsAffected,
  182. }, m.err
  183. }
  184. func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
  185. return m.QueryContext(context.Background(), query, args...)
  186. }
  187. func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
  188. if m.delay {
  189. time.Sleep(defaultSlowThreshold + time.Millisecond)
  190. }
  191. err := errMockedPlaceholder
  192. if m.err != nil {
  193. err = m.err
  194. }
  195. return new(sql.Rows), err
  196. }
  197. type mockedStmtConn struct {
  198. lastInsertId int64
  199. rowsAffected int64
  200. err error
  201. delay bool
  202. }
  203. func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
  204. return m.ExecContext(context.Background(), args...)
  205. }
  206. func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) {
  207. if m.delay {
  208. time.Sleep(defaultSlowThreshold + time.Millisecond)
  209. }
  210. return mockedResult{
  211. lastInsertId: m.lastInsertId,
  212. rowsAffected: m.rowsAffected,
  213. }, m.err
  214. }
  215. func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
  216. return m.QueryContext(context.Background(), args...)
  217. }
  218. func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) {
  219. if m.delay {
  220. time.Sleep(defaultSlowThreshold + time.Millisecond)
  221. }
  222. err := errMockedPlaceholder
  223. if m.err != nil {
  224. err = m.err
  225. }
  226. return new(sql.Rows), err
  227. }
  228. type mockedResult struct {
  229. lastInsertId int64
  230. rowsAffected int64
  231. }
  232. func (m mockedResult) LastInsertId() (int64, error) {
  233. return m.lastInsertId, nil
  234. }
  235. func (m mockedResult) RowsAffected() (int64, error) {
  236. return m.rowsAffected, nil
  237. }