bulkinserter_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package sqlx
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "strconv"
  7. "testing"
  8. "github.com/DATA-DOG/go-sqlmock"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/zeromicro/go-zero/internal/dbtest"
  11. )
  12. type mockedConn struct {
  13. query string
  14. args []any
  15. execErr error
  16. }
  17. func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
  18. c.query = query
  19. c.args = args
  20. return nil, c.execErr
  21. }
  22. func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
  23. panic("implement me")
  24. }
  25. func (c *mockedConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
  26. panic("implement me")
  27. }
  28. func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
  29. panic("implement me")
  30. }
  31. func (c *mockedConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
  32. panic("implement me")
  33. }
  34. func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
  35. panic("implement me")
  36. }
  37. func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
  38. panic("should not called")
  39. }
  40. func (c *mockedConn) Exec(query string, args ...any) (sql.Result, error) {
  41. return c.ExecCtx(context.Background(), query, args...)
  42. }
  43. func (c *mockedConn) Prepare(query string) (StmtSession, error) {
  44. panic("should not called")
  45. }
  46. func (c *mockedConn) QueryRow(v any, query string, args ...any) error {
  47. panic("should not called")
  48. }
  49. func (c *mockedConn) QueryRowPartial(v any, query string, args ...any) error {
  50. panic("should not called")
  51. }
  52. func (c *mockedConn) QueryRows(v any, query string, args ...any) error {
  53. panic("should not called")
  54. }
  55. func (c *mockedConn) QueryRowsPartial(v any, query string, args ...any) error {
  56. panic("should not called")
  57. }
  58. func (c *mockedConn) RawDB() (*sql.DB, error) {
  59. panic("should not called")
  60. }
  61. func (c *mockedConn) Transact(func(session Session) error) error {
  62. panic("should not called")
  63. }
  64. func TestBulkInserter(t *testing.T) {
  65. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  66. var conn mockedConn
  67. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
  68. assert.Nil(t, err)
  69. for i := 0; i < 5; i++ {
  70. assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
  71. }
  72. inserter.Flush()
  73. assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
  74. `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
  75. `('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
  76. conn.query)
  77. assert.Nil(t, conn.args)
  78. })
  79. }
  80. func TestBulkInserterSuffix(t *testing.T) {
  81. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  82. var conn mockedConn
  83. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
  84. `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
  85. assert.Nil(t, err)
  86. assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+
  87. `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`))
  88. for i := 0; i < 5; i++ {
  89. assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
  90. }
  91. inserter.SetResultHandler(func(result sql.Result, err error) {})
  92. inserter.Flush()
  93. assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
  94. `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
  95. `('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
  96. conn.query)
  97. assert.Nil(t, conn.args)
  98. })
  99. }
  100. func TestBulkInserterBadStatement(t *testing.T) {
  101. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  102. var conn mockedConn
  103. _, err := NewBulkInserter(&conn, "foo")
  104. assert.NotNil(t, err)
  105. })
  106. }
  107. func TestBulkInserter_Update(t *testing.T) {
  108. conn := mockedConn{
  109. execErr: errors.New("foo"),
  110. }
  111. _, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES()`)
  112. assert.NotNil(t, err)
  113. _, err = NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`)
  114. assert.NotNil(t, err)
  115. inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
  116. assert.Nil(t, err)
  117. inserter.inserter.Execute([]string{"bar"})
  118. inserter.SetResultHandler(func(result sql.Result, err error) {
  119. })
  120. inserter.UpdateOrDelete(func() {})
  121. inserter.inserter.Execute([]string(nil))
  122. assert.NotNil(t, inserter.UpdateStmt("foo"))
  123. assert.NotNil(t, inserter.Insert("foo", "bar"))
  124. }