utils.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package sqlx
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strconv"
  7. "strings"
  8. "github.com/zeromicro/go-zero/core/logx"
  9. "github.com/zeromicro/go-zero/core/mapping"
  10. )
  // errUnbalancedEscape is returned by format when a quoted literal in the
  // query ends with a backslash that has no character left to escape.
  var (
  	errUnbalancedEscape = errors.New("no char after escape char")
  )
  // desensitize removes the account part of datasource (everything up to and
  // including the last '@') so the result can be put in a log message without
  // exposing the user name or password. A datasource without '@' is returned
  // unchanged.
  func desensitize(datasource string) string {
  	// The last '@' is used so that an '@' inside the password does not leave
  	// part of the credentials behind.
  	if pos := strings.LastIndexByte(datasource, '@'); pos >= 0 {
  		// Drop the account even when nothing follows the '@'; returning the
  		// input verbatim in that case would log the credentials.
  		return datasource[pos+1:]
  	}
  	return datasource
  }
  // escape returns input with each NUL, CR, LF, backslash, single quote,
  // double quote and 0x1a character replaced by a backslash escape sequence,
  // so the result can be wrapped in single quotes (as writeValue does). All
  // other runes are copied through. Because input is decoded rune by rune,
  // invalid UTF-8 bytes come out as U+FFFD.
  func escape(input string) string {
  	var out strings.Builder
  	for _, r := range input {
  		var seq string
  		switch r {
  		case '\x00':
  			seq = `\x00`
  		case '\r':
  			seq = `\r`
  		case '\n':
  			seq = `\n`
  		case '\\':
  			seq = `\\`
  		case '\'':
  			seq = `\'`
  		case '"':
  			seq = `\"`
  		case '\x1a':
  			seq = `\x1a`
  		}

  		if seq == "" {
  			out.WriteRune(r)
  			continue
  		}
  		out.WriteString(seq)
  	}
  	return out.String()
  }
  44. func format(query string, args ...interface{}) (string, error) {
  45. numArgs := len(args)
  46. if numArgs == 0 {
  47. return query, nil
  48. }
  49. var b strings.Builder
  50. var argIndex int
  51. bytes := len(query)
  52. for i := 0; i < bytes; i++ {
  53. ch := query[i]
  54. switch ch {
  55. case '?':
  56. if argIndex >= numArgs {
  57. return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
  58. }
  59. writeValue(&b, args[argIndex])
  60. argIndex++
  61. case ':', '$':
  62. var j int
  63. for j = i + 1; j < bytes; j++ {
  64. char := query[j]
  65. if char < '0' || '9' < char {
  66. break
  67. }
  68. }
  69. if j > i+1 {
  70. index, err := strconv.Atoi(query[i+1 : j])
  71. if err != nil {
  72. return "", err
  73. }
  74. // index starts from 1 for pg or oracle
  75. if index > argIndex {
  76. argIndex = index
  77. }
  78. index--
  79. if index < 0 || numArgs <= index {
  80. return "", fmt.Errorf("error: wrong index %d in sql", index)
  81. }
  82. writeValue(&b, args[index])
  83. i = j - 1
  84. }
  85. case '\'', '"', '`':
  86. b.WriteByte(ch)
  87. for j := i + 1; j < bytes; j++ {
  88. cur := query[j]
  89. b.WriteByte(cur)
  90. switch cur {
  91. case '\\':
  92. j++
  93. if j >= bytes {
  94. return "", errUnbalancedEscape
  95. }
  96. b.WriteByte(query[j])
  97. case '\'', '"', '`':
  98. if cur == ch {
  99. i = j
  100. goto end
  101. }
  102. }
  103. }
  104. end:
  105. break
  106. default:
  107. b.WriteByte(ch)
  108. }
  109. }
  110. if argIndex < numArgs {
  111. return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
  112. }
  113. return b.String(), nil
  114. }
  115. func logInstanceError(datasource string, err error) {
  116. datasource = desensitize(datasource)
  117. logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
  118. }
  119. func logSqlError(ctx context.Context, stmt string, err error) {
  120. if err != nil && err != ErrNotFound {
  121. logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
  122. }
  123. }
  124. func writeValue(buf *strings.Builder, arg interface{}) {
  125. switch v := arg.(type) {
  126. case bool:
  127. if v {
  128. buf.WriteByte('1')
  129. } else {
  130. buf.WriteByte('0')
  131. }
  132. case string:
  133. buf.WriteByte('\'')
  134. buf.WriteString(escape(v))
  135. buf.WriteByte('\'')
  136. default:
  137. buf.WriteString(mapping.Repr(v))
  138. }
  139. }