|
@@ -5,6 +5,9 @@ import (
|
|
|
"database/sql"
|
|
|
"errors"
|
|
|
"strconv"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+ "sync/atomic"
|
|
|
"testing"
|
|
|
|
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
@@ -13,14 +16,19 @@ import (
|
|
|
)
|
|
|
|
|
|
type mockedConn struct {
|
|
|
- query string
|
|
|
- args []any
|
|
|
- execErr error
|
|
|
+ query string
|
|
|
+ args []any
|
|
|
+ execErr error
|
|
|
+ updateCallback func(query string, args []any)
|
|
|
}
|
|
|
|
|
|
func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
|
|
|
c.query = query
|
|
|
c.args = args
|
|
|
+ if c.updateCallback != nil {
|
|
|
+ c.updateCallback(query, args)
|
|
|
+ }
|
|
|
+
|
|
|
return nil, c.execErr
|
|
|
}
|
|
|
|
|
@@ -144,3 +152,50 @@ func TestBulkInserter_Update(t *testing.T) {
|
|
|
assert.NotNil(t, inserter.UpdateStmt("foo"))
|
|
|
assert.NotNil(t, inserter.Insert("foo", "bar"))
|
|
|
}
|
|
|
+
|
|
|
+func TestBulkInserter_UpdateStmt(t *testing.T) {
|
|
|
+ var updated int32
|
|
|
+ conn := mockedConn{
|
|
|
+ execErr: errors.New("foo"),
|
|
|
+ updateCallback: func(query string, args []any) {
|
|
|
+ count := atomic.AddInt32(&updated, 1)
|
|
|
+ assert.Empty(t, args)
|
|
|
+ assert.Equal(t, 100, strings.Count(query, "foo"))
|
|
|
+ if count == 1 {
|
|
|
+ assert.Equal(t, 0, strings.Count(query, "bar"))
|
|
|
+ } else {
|
|
|
+ assert.Equal(t, 100, strings.Count(query, "bar"))
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom) VALUES(?)`)
|
|
|
+ assert.NoError(t, err)
|
|
|
+
|
|
|
+ var wg1 sync.WaitGroup
|
|
|
+ wg1.Add(2)
|
|
|
+ for i := 0; i < 2; i++ {
|
|
|
+ go func() {
|
|
|
+ defer wg1.Done()
|
|
|
+ for i := 0; i < 50; i++ {
|
|
|
+ assert.NoError(t, inserter.Insert("foo"))
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ wg1.Wait()
|
|
|
+
|
|
|
+ assert.NoError(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user) VALUES(?, ?)`))
|
|
|
+
|
|
|
+ var wg2 sync.WaitGroup
|
|
|
+ wg2.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer wg2.Done()
|
|
|
+ for i := 0; i < 100; i++ {
|
|
|
+ assert.NoError(t, inserter.Insert("foo", "bar"))
|
|
|
+ }
|
|
|
+ inserter.Flush()
|
|
|
+ }()
|
|
|
+ wg2.Wait()
|
|
|
+
|
|
|
+ assert.Equal(t, int32(2), atomic.LoadInt32(&updated))
|
|
|
+}
|