bulkinserter.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/zeromicro/go-zero/core/executors"
  8. "github.com/zeromicro/go-zero/core/logx"
  9. "github.com/zeromicro/go-zero/core/stringx"
  10. )
  const (
  	// valuesKeyword is the lower-case keyword parseInsertStmt splits an
  	// insert statement on.
  	valuesKeyword = "values"
  	// maxBulkRows is the number of buffered value tuples at which
  	// dbInserter.AddTask reports the batch as full.
  	maxBulkRows = 1000
  	// flushInterval is the period handed to the periodical executor that
  	// NewBulkInserter creates.
  	flushInterval = time.Second
  )
  16. var emptyBulkStmt bulkStmt
  17. type (
  18. // ResultHandler defines the method of result handlers.
  19. ResultHandler func(sql.Result, error)
  20. // A BulkInserter is used to batch insert records.
  21. // Postgresql is not supported yet, because of the sql is formated with symbol `$`.
  22. // Oracle is not supported yet, because of the sql is formated with symbol `:`.
  23. BulkInserter struct {
  24. executor *executors.PeriodicalExecutor
  25. inserter *dbInserter
  26. stmt bulkStmt
  27. }
  28. bulkStmt struct {
  29. prefix string
  30. valueFormat string
  31. suffix string
  32. }
  33. )
  34. // NewBulkInserter returns a BulkInserter.
  35. func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
  36. bkStmt, err := parseInsertStmt(stmt)
  37. if err != nil {
  38. return nil, err
  39. }
  40. inserter := &dbInserter{
  41. sqlConn: sqlConn,
  42. stmt: bkStmt,
  43. }
  44. return &BulkInserter{
  45. executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
  46. inserter: inserter,
  47. stmt: bkStmt,
  48. }, nil
  49. }
  50. // Flush flushes all the pending records.
  51. func (bi *BulkInserter) Flush() {
  52. bi.executor.Flush()
  53. }
  54. // Insert inserts given args.
  55. func (bi *BulkInserter) Insert(args ...any) error {
  56. value, err := format(bi.stmt.valueFormat, args...)
  57. if err != nil {
  58. return err
  59. }
  60. bi.executor.Add(value)
  61. return nil
  62. }
  63. // SetResultHandler sets the given handler.
  64. func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
  65. bi.executor.Sync(func() {
  66. bi.inserter.resultHandler = handler
  67. })
  68. }
  69. // UpdateOrDelete runs update or delete queries, which flushes pending records first.
  70. func (bi *BulkInserter) UpdateOrDelete(fn func()) {
  71. bi.executor.Flush()
  72. fn()
  73. }
  74. // UpdateStmt updates the insert statement.
  75. func (bi *BulkInserter) UpdateStmt(stmt string) error {
  76. bkStmt, err := parseInsertStmt(stmt)
  77. if err != nil {
  78. return err
  79. }
  80. bi.executor.Flush()
  81. bi.executor.Sync(func() {
  82. bi.inserter.stmt = bkStmt
  83. })
  84. return nil
  85. }
  86. type dbInserter struct {
  87. sqlConn SqlConn
  88. stmt bulkStmt
  89. values []string
  90. resultHandler ResultHandler
  91. }
  92. func (in *dbInserter) AddTask(task any) bool {
  93. in.values = append(in.values, task.(string))
  94. return len(in.values) >= maxBulkRows
  95. }
  96. func (in *dbInserter) Execute(bulk any) {
  97. values := bulk.([]string)
  98. if len(values) == 0 {
  99. return
  100. }
  101. stmtWithoutValues := in.stmt.prefix
  102. valuesStr := strings.Join(values, ", ")
  103. stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
  104. if len(in.stmt.suffix) > 0 {
  105. stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
  106. }
  107. result, err := in.sqlConn.Exec(stmt)
  108. if in.resultHandler != nil {
  109. in.resultHandler(result, err)
  110. } else if err != nil {
  111. logx.Errorf("sql: %s, error: %s", stmt, err)
  112. }
  113. }
  114. func (in *dbInserter) RemoveAll() any {
  115. values := in.values
  116. in.values = nil
  117. return values
  118. }
  119. func parseInsertStmt(stmt string) (bulkStmt, error) {
  120. lower := strings.ToLower(stmt)
  121. pos := strings.Index(lower, valuesKeyword)
  122. if pos <= 0 {
  123. return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
  124. }
  125. var columns int
  126. right := strings.LastIndexByte(lower[:pos], ')')
  127. if right > 0 {
  128. left := strings.LastIndexByte(lower[:right], '(')
  129. if left > 0 {
  130. values := lower[left+1 : right]
  131. values = stringx.Filter(values, func(r rune) bool {
  132. return r == ' ' || r == '\t' || r == '\r' || r == '\n'
  133. })
  134. fields := strings.FieldsFunc(values, func(r rune) bool {
  135. return r == ','
  136. })
  137. columns = len(fields)
  138. }
  139. }
  140. var variables int
  141. var valueFormat string
  142. var suffix string
  143. left := strings.IndexByte(lower[pos:], '(')
  144. if left > 0 {
  145. right = strings.IndexByte(lower[pos+left:], ')')
  146. if right > 0 {
  147. values := lower[pos+left : pos+left+right]
  148. for _, x := range values {
  149. if x == '?' {
  150. variables++
  151. }
  152. }
  153. valueFormat = stmt[pos+left : pos+left+right+1]
  154. suffix = strings.TrimSpace(stmt[pos+left+right+1:])
  155. }
  156. }
  157. if variables == 0 {
  158. return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
  159. }
  160. if columns > 0 && columns != variables {
  161. return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
  162. }
  163. return bulkStmt{
  164. prefix: stmt[:pos+len(valuesKeyword)],
  165. valueFormat: valueFormat,
  166. suffix: suffix,
  167. }, nil
  168. }