sqlconn_test.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "errors"
  5. "io"
  6. "testing"
  7. "github.com/DATA-DOG/go-sqlmock"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/zeromicro/go-zero/core/breaker"
  10. "github.com/zeromicro/go-zero/core/logx"
  11. "github.com/zeromicro/go-zero/core/stores/dbtest"
  12. "github.com/zeromicro/go-zero/core/trace/tracetest"
  13. )
  14. const mockedDatasource = "sqlmock"
  15. func init() {
  16. logx.Disable()
  17. }
  18. func TestSqlConn(t *testing.T) {
  19. me := tracetest.NewInMemoryExporter(t)
  20. mock, err := buildConn()
  21. assert.Nil(t, err)
  22. mock.ExpectExec("any")
  23. mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
  24. conn := NewMysql(mockedDatasource)
  25. db, err := conn.RawDB()
  26. assert.Nil(t, err)
  27. rawConn := NewSqlConnFromDB(db, withMysqlAcceptable())
  28. badConn := NewMysql("badsql")
  29. _, err = conn.Exec("any", "value")
  30. assert.NotNil(t, err)
  31. _, err = badConn.Exec("any", "value")
  32. assert.NotNil(t, err)
  33. _, err = rawConn.Prepare("any")
  34. assert.NotNil(t, err)
  35. _, err = badConn.Prepare("any")
  36. assert.NotNil(t, err)
  37. var val string
  38. assert.NotNil(t, conn.QueryRow(&val, "any"))
  39. assert.NotNil(t, badConn.QueryRow(&val, "any"))
  40. assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
  41. assert.NotNil(t, badConn.QueryRowPartial(&val, "any"))
  42. assert.NotNil(t, conn.QueryRows(&val, "any"))
  43. assert.NotNil(t, badConn.QueryRows(&val, "any"))
  44. assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
  45. assert.NotNil(t, badConn.QueryRowsPartial(&val, "any"))
  46. assert.NotNil(t, conn.Transact(func(session Session) error {
  47. return nil
  48. }))
  49. assert.NotNil(t, badConn.Transact(func(session Session) error {
  50. return nil
  51. }))
  52. assert.Equal(t, 14, len(me.GetSpans()))
  53. }
  54. func TestSqlConn_RawDB(t *testing.T) {
  55. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  56. rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
  57. mock.ExpectQuery("any").WillReturnRows(rows)
  58. conn := NewSqlConnFromDB(db)
  59. var val string
  60. assert.NoError(t, conn.QueryRow(&val, "any"))
  61. assert.Equal(t, "bar", val)
  62. })
  63. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  64. rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
  65. mock.ExpectQuery("any").WillReturnRows(rows)
  66. conn := NewSqlConnFromDB(db)
  67. var val string
  68. assert.NoError(t, conn.QueryRowPartial(&val, "any"))
  69. assert.Equal(t, "bar", val)
  70. })
  71. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  72. rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
  73. mock.ExpectQuery("any").WillReturnRows(rows)
  74. conn := NewSqlConnFromDB(db)
  75. var vals []string
  76. assert.NoError(t, conn.QueryRows(&vals, "any"))
  77. assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
  78. })
  79. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  80. rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
  81. mock.ExpectQuery("any").WillReturnRows(rows)
  82. conn := NewSqlConnFromDB(db)
  83. var vals []string
  84. assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
  85. assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
  86. })
  87. }
  88. func TestSqlConn_Errors(t *testing.T) {
  89. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  90. conn := NewSqlConnFromDB(db)
  91. conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
  92. return nil, errors.New("error")
  93. }
  94. _, err := conn.Prepare("any")
  95. assert.Error(t, err)
  96. })
  97. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  98. mock.ExpectExec("any").WillReturnError(breaker.ErrServiceUnavailable)
  99. conn := NewSqlConnFromDB(db)
  100. _, err := conn.Exec("any")
  101. assert.Error(t, err)
  102. })
  103. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  104. mock.ExpectPrepare("any").WillReturnError(breaker.ErrServiceUnavailable)
  105. conn := NewSqlConnFromDB(db)
  106. _, err := conn.Prepare("any")
  107. assert.Error(t, err)
  108. })
  109. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  110. mock.ExpectBegin()
  111. mock.ExpectRollback()
  112. conn := NewSqlConnFromDB(db)
  113. err := conn.Transact(func(session Session) error {
  114. return breaker.ErrServiceUnavailable
  115. })
  116. assert.Equal(t, breaker.ErrServiceUnavailable, err)
  117. })
  118. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  119. mock.ExpectQuery("any").WillReturnError(breaker.ErrServiceUnavailable)
  120. conn := NewSqlConnFromDB(db)
  121. var vals []string
  122. err := conn.QueryRows(&vals, "any")
  123. assert.Equal(t, breaker.ErrServiceUnavailable, err)
  124. })
  125. }
  126. func TestStatement(t *testing.T) {
  127. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  128. mock.ExpectPrepare("any").WillBeClosed()
  129. conn := NewSqlConnFromDB(db)
  130. stmt, err := conn.Prepare("any")
  131. assert.NoError(t, err)
  132. assert.NoError(t, stmt.Close())
  133. })
  134. dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
  135. mock.ExpectPrepare("any").WillBeClosed()
  136. stmt, err := tx.Prepare("any")
  137. assert.NoError(t, err)
  138. st := statement{
  139. query: "foo",
  140. stmt: stmt,
  141. }
  142. assert.NoError(t, st.Close())
  143. })
  144. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  145. mock.ExpectPrepare("any")
  146. mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
  147. conn := NewSqlConnFromDB(db)
  148. stmt, err := conn.Prepare("any")
  149. assert.NoError(t, err)
  150. res, err := stmt.Exec()
  151. assert.NoError(t, err)
  152. lastInsertID, err := res.LastInsertId()
  153. assert.NoError(t, err)
  154. assert.Equal(t, int64(2), lastInsertID)
  155. rowsAffected, err := res.RowsAffected()
  156. assert.NoError(t, err)
  157. assert.Equal(t, int64(3), rowsAffected)
  158. })
  159. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  160. mock.ExpectPrepare("any")
  161. row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
  162. mock.ExpectQuery("any").WillReturnRows(row)
  163. conn := NewSqlConnFromDB(db)
  164. stmt, err := conn.Prepare("any")
  165. assert.NoError(t, err)
  166. var val string
  167. err = stmt.QueryRow(&val)
  168. assert.NoError(t, err)
  169. assert.Equal(t, "bar", val)
  170. })
  171. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  172. mock.ExpectPrepare("any")
  173. row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
  174. mock.ExpectQuery("any").WillReturnRows(row)
  175. conn := NewSqlConnFromDB(db)
  176. stmt, err := conn.Prepare("any")
  177. assert.NoError(t, err)
  178. var val string
  179. err = stmt.QueryRowPartial(&val)
  180. assert.NoError(t, err)
  181. assert.Equal(t, "bar", val)
  182. })
  183. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  184. mock.ExpectPrepare("any")
  185. rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
  186. mock.ExpectQuery("any").WillReturnRows(rows)
  187. conn := NewSqlConnFromDB(db)
  188. stmt, err := conn.Prepare("any")
  189. assert.NoError(t, err)
  190. var vals []string
  191. assert.NoError(t, stmt.QueryRows(&vals))
  192. assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
  193. })
  194. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  195. mock.ExpectPrepare("any")
  196. rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
  197. mock.ExpectQuery("any").WillReturnRows(rows)
  198. conn := NewSqlConnFromDB(db)
  199. stmt, err := conn.Prepare("any")
  200. assert.NoError(t, err)
  201. var vals []string
  202. assert.NoError(t, stmt.QueryRowsPartial(&vals))
  203. assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
  204. })
  205. }
  206. func TestBreakerWithFormatError(t *testing.T) {
  207. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  208. conn := NewSqlConnFromDB(db, withMysqlAcceptable())
  209. for i := 0; i < 1000; i++ {
  210. var val string
  211. if !assert.NotEqual(t, breaker.ErrServiceUnavailable,
  212. conn.QueryRow(&val, "any ?, ?", "foo")) {
  213. break
  214. }
  215. }
  216. })
  217. }
  218. func TestBreakerWithScanError(t *testing.T) {
  219. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  220. conn := NewSqlConnFromDB(db, withMysqlAcceptable())
  221. for i := 0; i < 1000; i++ {
  222. rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
  223. mock.ExpectQuery("any").WillReturnRows(rows)
  224. var val int
  225. if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) {
  226. break
  227. }
  228. }
  229. })
  230. }
  231. func buildConn() (mock sqlmock.Sqlmock, err error) {
  232. _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
  233. var db *sql.DB
  234. var err error
  235. db, mock, err = sqlmock.New()
  236. return db, err
  237. })
  238. return
  239. }