bulkinserter_test.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "errors"
  5. "strconv"
  6. "testing"
  7. "github.com/DATA-DOG/go-sqlmock"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/tal-tech/go-zero/core/logx"
  10. )
  11. type mockedConn struct {
  12. query string
  13. args []interface{}
  14. execErr error
  15. }
  16. func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
  17. c.query = query
  18. c.args = args
  19. return nil, c.execErr
  20. }
  21. func (c *mockedConn) Prepare(query string) (StmtSession, error) {
  22. panic("should not called")
  23. }
  24. func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
  25. panic("should not called")
  26. }
  27. func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
  28. panic("should not called")
  29. }
  30. func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
  31. panic("should not called")
  32. }
  33. func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
  34. panic("should not called")
  35. }
  36. func (c *mockedConn) RawDB() (*sql.DB, error) {
  37. panic("should not called")
  38. }
  39. func (c *mockedConn) Transact(func(session Session) error) error {
  40. panic("should not called")
  41. }
  42. func TestBulkInserter(t *testing.T) {
  43. runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  44. var conn mockedConn
  45. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
  46. assert.Nil(t, err)
  47. for i := 0; i < 5; i++ {
  48. assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
  49. }
  50. inserter.Flush()
  51. assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
  52. `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
  53. `('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
  54. conn.query)
  55. assert.Nil(t, conn.args)
  56. })
  57. }
  58. func TestBulkInserterSuffix(t *testing.T) {
  59. runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  60. var conn mockedConn
  61. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
  62. `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
  63. assert.Nil(t, err)
  64. assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+
  65. `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`))
  66. for i := 0; i < 5; i++ {
  67. assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
  68. }
  69. inserter.SetResultHandler(func(result sql.Result, err error) {})
  70. inserter.Flush()
  71. assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
  72. `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
  73. `('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
  74. conn.query)
  75. assert.Nil(t, conn.args)
  76. })
  77. }
  78. func TestBulkInserterBadStatement(t *testing.T) {
  79. runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  80. var conn mockedConn
  81. _, err := NewBulkInserter(&conn, "foo")
  82. assert.NotNil(t, err)
  83. })
  84. }
  85. func TestBulkInserter_Update(t *testing.T) {
  86. conn := mockedConn{
  87. execErr: errors.New("foo"),
  88. }
  89. _, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES()`)
  90. assert.NotNil(t, err)
  91. _, err = NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`)
  92. assert.NotNil(t, err)
  93. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
  94. assert.Nil(t, err)
  95. inserter.inserter.Execute([]string{"bar"})
  96. inserter.SetResultHandler(func(result sql.Result, err error) {
  97. })
  98. inserter.UpdateOrDelete(func() {})
  99. inserter.inserter.Execute([]string(nil))
  100. assert.NotNil(t, inserter.UpdateStmt("foo"))
  101. assert.NotNil(t, inserter.Insert("foo", "bar"))
  102. }
  103. func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
  104. logx.Disable()
  105. db, mock, err := sqlmock.New()
  106. if err != nil {
  107. t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
  108. }
  109. defer db.Close()
  110. fn(db, mock)
  111. if err := mock.ExpectationsWereMet(); err != nil {
  112. t.Errorf("there were unfulfilled expectations: %s", err)
  113. }
  114. }