// bulkinserter.go (4.2 KB)
//
// Note: the web viewer's line-number gutter (lines 1–195) was removed from this listing.
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/tal-tech/go-zero/core/executors"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. "github.com/tal-tech/go-zero/core/stringx"
  10. )
  // Batching limits used by BulkInserter.
const (
	// flushInterval is the period handed to executors.NewPeriodicalExecutor.
	flushInterval = time.Second
	// maxBulkRows is the number of buffered rows at which dbInserter.AddTask
	// reports the batch as full.
	maxBulkRows = 1000
)

// valuesKeyword is the lower-case SQL keyword that precedes the value tuple
// of an INSERT statement.
const valuesKeyword = "values"
  16. var emptyBulkStmt bulkStmt
  17. type (
  18. // ResultHandler defines the method of result handlers.
  19. ResultHandler func(sql.Result, error)
  20. // A BulkInserter is used to batch insert records.
  21. BulkInserter struct {
  22. executor *executors.PeriodicalExecutor
  23. inserter *dbInserter
  24. stmt bulkStmt
  25. }
  26. bulkStmt struct {
  27. prefix string
  28. valueFormat string
  29. suffix string
  30. }
  31. )
  32. // NewBulkInserter returns a BulkInserter.
  33. func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
  34. bkStmt, err := parseInsertStmt(stmt)
  35. if err != nil {
  36. return nil, err
  37. }
  38. inserter := &dbInserter{
  39. sqlConn: sqlConn,
  40. stmt: bkStmt,
  41. }
  42. return &BulkInserter{
  43. executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
  44. inserter: inserter,
  45. stmt: bkStmt,
  46. }, nil
  47. }
  48. // Flush flushes all the pending records.
  49. func (bi *BulkInserter) Flush() {
  50. bi.executor.Flush()
  51. }
  52. // Insert inserts given args.
  53. func (bi *BulkInserter) Insert(args ...interface{}) error {
  54. value, err := format(bi.stmt.valueFormat, args...)
  55. if err != nil {
  56. return err
  57. }
  58. bi.executor.Add(value)
  59. return nil
  60. }
  61. // SetResultHandler sets the given handler.
  62. func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
  63. bi.executor.Sync(func() {
  64. bi.inserter.resultHandler = handler
  65. })
  66. }
  67. // UpdateOrDelete runs update or delete queries, which flushes pending records first.
  68. func (bi *BulkInserter) UpdateOrDelete(fn func()) {
  69. bi.executor.Flush()
  70. fn()
  71. }
  72. // UpdateStmt updates the insert statement.
  73. func (bi *BulkInserter) UpdateStmt(stmt string) error {
  74. bkStmt, err := parseInsertStmt(stmt)
  75. if err != nil {
  76. return err
  77. }
  78. bi.executor.Flush()
  79. bi.executor.Sync(func() {
  80. bi.inserter.stmt = bkStmt
  81. })
  82. return nil
  83. }
  84. type dbInserter struct {
  85. sqlConn SqlConn
  86. stmt bulkStmt
  87. values []string
  88. resultHandler ResultHandler
  89. }
  90. func (in *dbInserter) AddTask(task interface{}) bool {
  91. in.values = append(in.values, task.(string))
  92. return len(in.values) >= maxBulkRows
  93. }
  94. func (in *dbInserter) Execute(bulk interface{}) {
  95. values := bulk.([]string)
  96. if len(values) == 0 {
  97. return
  98. }
  99. stmtWithoutValues := in.stmt.prefix
  100. valuesStr := strings.Join(values, ", ")
  101. stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
  102. if len(in.stmt.suffix) > 0 {
  103. stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
  104. }
  105. result, err := in.sqlConn.Exec(stmt)
  106. if in.resultHandler != nil {
  107. in.resultHandler(result, err)
  108. } else if err != nil {
  109. logx.Errorf("sql: %s, error: %s", stmt, err)
  110. }
  111. }
  112. func (in *dbInserter) RemoveAll() interface{} {
  113. values := in.values
  114. in.values = nil
  115. return values
  116. }
  117. func parseInsertStmt(stmt string) (bulkStmt, error) {
  118. lower := strings.ToLower(stmt)
  119. pos := strings.Index(lower, valuesKeyword)
  120. if pos <= 0 {
  121. return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
  122. }
  123. var columns int
  124. right := strings.LastIndexByte(lower[:pos], ')')
  125. if right > 0 {
  126. left := strings.LastIndexByte(lower[:right], '(')
  127. if left > 0 {
  128. values := lower[left+1 : right]
  129. values = stringx.Filter(values, func(r rune) bool {
  130. return r == ' ' || r == '\t' || r == '\r' || r == '\n'
  131. })
  132. fields := strings.FieldsFunc(values, func(r rune) bool {
  133. return r == ','
  134. })
  135. columns = len(fields)
  136. }
  137. }
  138. var variables int
  139. var valueFormat string
  140. var suffix string
  141. left := strings.IndexByte(lower[pos:], '(')
  142. if left > 0 {
  143. right = strings.IndexByte(lower[pos+left:], ')')
  144. if right > 0 {
  145. values := lower[pos+left : pos+left+right]
  146. for _, x := range values {
  147. if x == '?' {
  148. variables++
  149. }
  150. }
  151. valueFormat = stmt[pos+left : pos+left+right+1]
  152. suffix = strings.TrimSpace(stmt[pos+left+right+1:])
  153. }
  154. }
  155. if variables == 0 {
  156. return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
  157. }
  158. if columns > 0 && columns != variables {
  159. return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
  160. }
  161. return bulkStmt{
  162. prefix: stmt[:pos+len(valuesKeyword)],
  163. valueFormat: valueFormat,
  164. suffix: suffix,
  165. }, nil
  166. }