1
0

bulkinserter.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. "sync"
  7. "time"
  8. "github.com/wuntsong-org/go-zero-plus/core/executors"
  9. "github.com/wuntsong-org/go-zero-plus/core/logx"
  10. "github.com/wuntsong-org/go-zero-plus/core/stringx"
  11. )
  12. const (
  13. flushInterval = time.Second
  14. maxBulkRows = 1000
  15. valuesKeyword = "values"
  16. )
  17. var emptyBulkStmt bulkStmt
  18. type (
  19. // ResultHandler defines the method of result handlers.
  20. ResultHandler func(sql.Result, error)
  21. // A BulkInserter is used to batch insert records.
  22. // Postgresql is not supported yet, because of the sql is formated with symbol `$`.
  23. // Oracle is not supported yet, because of the sql is formated with symbol `:`.
  24. BulkInserter struct {
  25. executor *executors.PeriodicalExecutor
  26. inserter *dbInserter
  27. stmt bulkStmt
  28. lock sync.RWMutex // guards stmt
  29. }
  30. bulkStmt struct {
  31. prefix string
  32. valueFormat string
  33. suffix string
  34. }
  35. )
  36. // NewBulkInserter returns a BulkInserter.
  37. func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
  38. bkStmt, err := parseInsertStmt(stmt)
  39. if err != nil {
  40. return nil, err
  41. }
  42. inserter := &dbInserter{
  43. sqlConn: sqlConn,
  44. stmt: bkStmt,
  45. }
  46. return &BulkInserter{
  47. executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
  48. inserter: inserter,
  49. stmt: bkStmt,
  50. }, nil
  51. }
  52. // Flush flushes all the pending records.
  53. func (bi *BulkInserter) Flush() {
  54. bi.executor.Flush()
  55. }
  56. // Insert inserts given args.
  57. func (bi *BulkInserter) Insert(args ...any) error {
  58. bi.lock.RLock()
  59. defer bi.lock.RUnlock()
  60. value, err := format(bi.stmt.valueFormat, args...)
  61. if err != nil {
  62. return err
  63. }
  64. bi.executor.Add(value)
  65. return nil
  66. }
  67. // SetResultHandler sets the given handler.
  68. func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
  69. bi.executor.Sync(func() {
  70. bi.inserter.resultHandler = handler
  71. })
  72. }
  73. // UpdateOrDelete runs update or delete queries, which flushes pending records first.
  74. func (bi *BulkInserter) UpdateOrDelete(fn func()) {
  75. bi.executor.Flush()
  76. fn()
  77. }
  78. // UpdateStmt updates the insert statement.
  79. func (bi *BulkInserter) UpdateStmt(stmt string) error {
  80. bkStmt, err := parseInsertStmt(stmt)
  81. if err != nil {
  82. return err
  83. }
  84. bi.lock.Lock()
  85. defer bi.lock.Unlock()
  86. // with write lock, it doesn't matter what's the order of setting bi.stmt and calling flush.
  87. bi.stmt = bkStmt
  88. bi.executor.Flush()
  89. bi.executor.Sync(func() {
  90. bi.inserter.stmt = bkStmt
  91. })
  92. return nil
  93. }
  94. type dbInserter struct {
  95. sqlConn SqlConn
  96. stmt bulkStmt
  97. values []string
  98. resultHandler ResultHandler
  99. }
  100. func (in *dbInserter) AddTask(task any) bool {
  101. in.values = append(in.values, task.(string))
  102. return len(in.values) >= maxBulkRows
  103. }
  104. func (in *dbInserter) Execute(bulk any) {
  105. values := bulk.([]string)
  106. if len(values) == 0 {
  107. return
  108. }
  109. stmtWithoutValues := in.stmt.prefix
  110. valuesStr := strings.Join(values, ", ")
  111. stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
  112. if len(in.stmt.suffix) > 0 {
  113. stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
  114. }
  115. result, err := in.sqlConn.Exec(stmt)
  116. if in.resultHandler != nil {
  117. in.resultHandler(result, err)
  118. } else if err != nil {
  119. logx.Errorf("sql: %s, error: %s", stmt, err)
  120. }
  121. }
  122. func (in *dbInserter) RemoveAll() any {
  123. values := in.values
  124. in.values = nil
  125. return values
  126. }
  127. func parseInsertStmt(stmt string) (bulkStmt, error) {
  128. lower := strings.ToLower(stmt)
  129. pos := strings.Index(lower, valuesKeyword)
  130. if pos <= 0 {
  131. return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
  132. }
  133. var columns int
  134. right := strings.LastIndexByte(lower[:pos], ')')
  135. if right > 0 {
  136. left := strings.LastIndexByte(lower[:right], '(')
  137. if left > 0 {
  138. values := lower[left+1 : right]
  139. values = stringx.Filter(values, func(r rune) bool {
  140. return r == ' ' || r == '\t' || r == '\r' || r == '\n'
  141. })
  142. fields := strings.FieldsFunc(values, func(r rune) bool {
  143. return r == ','
  144. })
  145. columns = len(fields)
  146. }
  147. }
  148. var variables int
  149. var valueFormat string
  150. var suffix string
  151. left := strings.IndexByte(lower[pos:], '(')
  152. if left > 0 {
  153. right = strings.IndexByte(lower[pos+left:], ')')
  154. if right > 0 {
  155. values := lower[pos+left : pos+left+right]
  156. for _, x := range values {
  157. if x == '?' {
  158. variables++
  159. }
  160. }
  161. valueFormat = stmt[pos+left : pos+left+right+1]
  162. suffix = strings.TrimSpace(stmt[pos+left+right+1:])
  163. }
  164. }
  165. if variables == 0 {
  166. return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
  167. }
  168. if columns > 0 && columns != variables {
  169. return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
  170. }
  171. return bulkStmt{
  172. prefix: stmt[:pos+len(valuesKeyword)],
  173. valueFormat: valueFormat,
  174. suffix: suffix,
  175. }, nil
  176. }