bulkinserter_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package sqlx
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "sync/atomic"
  10. "testing"
  11. "github.com/DATA-DOG/go-sqlmock"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/wuntsong-org/go-zero-plus/core/stores/dbtest"
  14. )
  15. type mockedConn struct {
  16. query string
  17. args []any
  18. execErr error
  19. updateCallback func(query string, args []any)
  20. }
  21. func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
  22. c.query = query
  23. c.args = args
  24. if c.updateCallback != nil {
  25. c.updateCallback(query, args)
  26. }
  27. return nil, c.execErr
  28. }
  29. func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
  30. panic("implement me")
  31. }
  32. func (c *mockedConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
  33. panic("implement me")
  34. }
  35. func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
  36. panic("implement me")
  37. }
  38. func (c *mockedConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
  39. panic("implement me")
  40. }
  41. func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
  42. panic("implement me")
  43. }
  44. func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
  45. panic("should not called")
  46. }
  47. func (c *mockedConn) Exec(query string, args ...any) (sql.Result, error) {
  48. return c.ExecCtx(context.Background(), query, args...)
  49. }
  50. func (c *mockedConn) Prepare(query string) (StmtSession, error) {
  51. panic("should not called")
  52. }
  53. func (c *mockedConn) QueryRow(v any, query string, args ...any) error {
  54. panic("should not called")
  55. }
  56. func (c *mockedConn) QueryRowPartial(v any, query string, args ...any) error {
  57. panic("should not called")
  58. }
  59. func (c *mockedConn) QueryRows(v any, query string, args ...any) error {
  60. panic("should not called")
  61. }
  62. func (c *mockedConn) QueryRowsPartial(v any, query string, args ...any) error {
  63. panic("should not called")
  64. }
  65. func (c *mockedConn) RawDB() (*sql.DB, error) {
  66. panic("should not called")
  67. }
  68. func (c *mockedConn) Transact(func(session Session) error) error {
  69. panic("should not called")
  70. }
  71. func TestBulkInserter(t *testing.T) {
  72. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  73. var conn mockedConn
  74. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
  75. assert.Nil(t, err)
  76. for i := 0; i < 5; i++ {
  77. assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
  78. }
  79. inserter.Flush()
  80. assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
  81. `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
  82. `('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
  83. conn.query)
  84. assert.Nil(t, conn.args)
  85. })
  86. }
  87. func TestBulkInserterSuffix(t *testing.T) {
  88. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  89. var conn mockedConn
  90. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
  91. `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
  92. assert.Nil(t, err)
  93. assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+
  94. `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`))
  95. for i := 0; i < 5; i++ {
  96. assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
  97. }
  98. inserter.SetResultHandler(func(result sql.Result, err error) {})
  99. inserter.Flush()
  100. assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
  101. `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
  102. `('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
  103. conn.query)
  104. assert.Nil(t, conn.args)
  105. })
  106. }
  107. func TestBulkInserterBadStatement(t *testing.T) {
  108. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  109. var conn mockedConn
  110. _, err := NewBulkInserter(&conn, "foo")
  111. assert.NotNil(t, err)
  112. })
  113. }
  114. func TestBulkInserter_Update(t *testing.T) {
  115. conn := mockedConn{
  116. execErr: errors.New("foo"),
  117. }
  118. _, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES()`)
  119. assert.NotNil(t, err)
  120. _, err = NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`)
  121. assert.NotNil(t, err)
  122. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
  123. assert.Nil(t, err)
  124. inserter.inserter.Execute([]string{"bar"})
  125. inserter.SetResultHandler(func(result sql.Result, err error) {
  126. })
  127. inserter.UpdateOrDelete(func() {})
  128. inserter.inserter.Execute([]string(nil))
  129. assert.NotNil(t, inserter.UpdateStmt("foo"))
  130. assert.NotNil(t, inserter.Insert("foo", "bar"))
  131. }
  132. func TestBulkInserter_UpdateStmt(t *testing.T) {
  133. var updated int32
  134. conn := mockedConn{
  135. execErr: errors.New("foo"),
  136. updateCallback: func(query string, args []any) {
  137. count := atomic.AddInt32(&updated, 1)
  138. assert.Empty(t, args)
  139. assert.Equal(t, 100, strings.Count(query, "foo"))
  140. if count == 1 {
  141. assert.Equal(t, 0, strings.Count(query, "bar"))
  142. } else {
  143. assert.Equal(t, 100, strings.Count(query, "bar"))
  144. }
  145. },
  146. }
  147. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom) VALUES(?)`)
  148. assert.NoError(t, err)
  149. var wg1 sync.WaitGroup
  150. wg1.Add(2)
  151. for i := 0; i < 2; i++ {
  152. go func() {
  153. defer wg1.Done()
  154. for i := 0; i < 50; i++ {
  155. assert.NoError(t, inserter.Insert("foo"))
  156. }
  157. }()
  158. }
  159. wg1.Wait()
  160. assert.NoError(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user) VALUES(?, ?)`))
  161. var wg2 sync.WaitGroup
  162. wg2.Add(1)
  163. go func() {
  164. defer wg2.Done()
  165. for i := 0; i < 100; i++ {
  166. assert.NoError(t, inserter.Insert("foo", "bar"))
  167. }
  168. inserter.Flush()
  169. }()
  170. wg2.Wait()
  171. assert.Equal(t, int32(2), atomic.LoadInt32(&updated))
  172. }