浏览代码

fix: UpdateStmt doesn't update the statement correctly in sqlx/bulkinserter.go (#3607)

Kevin Wan 1 年之前
父节点
当前提交
abd1fa96a9
共有 3 个文件被更改,包括 69 次插入4 次删除
  1. 1 1
      core/mapping/utils.go
  2. 10 0
      core/stores/sqlx/bulkinserter.go
  3. 58 3
      core/stores/sqlx/bulkinserter_test.go

+ 1 - 1
core/mapping/utils.go

@@ -30,7 +30,7 @@ const (
 	leftSquareBracket  = '['
 	rightSquareBracket = ']'
 	segmentSeparator   = ','
-	intSize            = 32 << (^uint(0) >> 63)
+	intSize            = 32 << (^uint(0) >> 63) // 32 or 64
 )
 
 var (

+ 10 - 0
core/stores/sqlx/bulkinserter.go

@@ -4,6 +4,7 @@ import (
 	"database/sql"
 	"fmt"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/zeromicro/go-zero/core/executors"
@@ -30,6 +31,7 @@ type (
 		executor *executors.PeriodicalExecutor
 		inserter *dbInserter
 		stmt     bulkStmt
+		lock     sync.RWMutex // guards stmt
 	}
 
 	bulkStmt struct {
@@ -65,6 +67,9 @@ func (bi *BulkInserter) Flush() {
 
 // Insert inserts given args.
 func (bi *BulkInserter) Insert(args ...any) error {
+	bi.lock.RLock()
+	defer bi.lock.RUnlock()
+
 	value, err := format(bi.stmt.valueFormat, args...)
 	if err != nil {
 		return err
@@ -95,6 +100,11 @@ func (bi *BulkInserter) UpdateStmt(stmt string) error {
 		return err
 	}
 
+	bi.lock.Lock()
+	defer bi.lock.Unlock()
+
+	// with write lock, it doesn't matter what's the order of setting bi.stmt and calling flush.
+	bi.stmt = bkStmt
 	bi.executor.Flush()
 	bi.executor.Sync(func() {
 		bi.inserter.stmt = bkStmt

+ 58 - 3
core/stores/sqlx/bulkinserter_test.go

@@ -5,6 +5,9 @@ import (
 	"database/sql"
 	"errors"
 	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 
 	"github.com/DATA-DOG/go-sqlmock"
@@ -13,14 +16,19 @@ import (
 )
 
 type mockedConn struct {
-	query   string
-	args    []any
-	execErr error
+	query          string
+	args           []any
+	execErr        error
+	updateCallback func(query string, args []any)
 }
 
 func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
 	c.query = query
 	c.args = args
+	if c.updateCallback != nil {
+		c.updateCallback(query, args)
+	}
+
 	return nil, c.execErr
 }
 
@@ -144,3 +152,50 @@ func TestBulkInserter_Update(t *testing.T) {
 	assert.NotNil(t, inserter.UpdateStmt("foo"))
 	assert.NotNil(t, inserter.Insert("foo", "bar"))
 }
+
+func TestBulkInserter_UpdateStmt(t *testing.T) {
+	var updated int32
+	conn := mockedConn{
+		execErr: errors.New("foo"),
+		updateCallback: func(query string, args []any) {
+			count := atomic.AddInt32(&updated, 1)
+			assert.Empty(t, args)
+			assert.Equal(t, 100, strings.Count(query, "foo"))
+			if count == 1 {
+				assert.Equal(t, 0, strings.Count(query, "bar"))
+			} else {
+				assert.Equal(t, 100, strings.Count(query, "bar"))
+			}
+		},
+	}
+
+	inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom) VALUES(?)`)
+	assert.NoError(t, err)
+
+	var wg1 sync.WaitGroup
+	wg1.Add(2)
+	for i := 0; i < 2; i++ {
+		go func() {
+			defer wg1.Done()
+			for i := 0; i < 50; i++ {
+				assert.NoError(t, inserter.Insert("foo"))
+			}
+		}()
+	}
+	wg1.Wait()
+
+	assert.NoError(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user) VALUES(?, ?)`))
+
+	var wg2 sync.WaitGroup
+	wg2.Add(1)
+	go func() {
+		defer wg2.Done()
+		for i := 0; i < 100; i++ {
+			assert.NoError(t, inserter.Insert("foo", "bar"))
+		}
+		inserter.Flush()
+	}()
+	wg2.Wait()
+
+	assert.Equal(t, int32(2), atomic.LoadInt32(&updated))
+}