utils.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package sqlx
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/logx"
  10. "github.com/zeromicro/go-zero/core/mapping"
  11. )
// errUnbalancedEscape is returned by format when a quoted section of the
// query ends immediately after a backslash, leaving no character to escape.
var errUnbalancedEscape = errors.New("no char after escape char")
  13. func desensitize(datasource string) string {
  14. // remove account
  15. pos := strings.LastIndex(datasource, "@")
  16. if 0 <= pos && pos+1 < len(datasource) {
  17. datasource = datasource[pos+1:]
  18. }
  19. return datasource
  20. }
  21. func escape(input string) string {
  22. var b strings.Builder
  23. for _, ch := range input {
  24. switch ch {
  25. case '\x00':
  26. b.WriteString(`\x00`)
  27. case '\r':
  28. b.WriteString(`\r`)
  29. case '\n':
  30. b.WriteString(`\n`)
  31. case '\\':
  32. b.WriteString(`\\`)
  33. case '\'':
  34. b.WriteString(`\'`)
  35. case '"':
  36. b.WriteString(`\"`)
  37. case '\x1a':
  38. b.WriteString(`\x1a`)
  39. default:
  40. b.WriteRune(ch)
  41. }
  42. }
  43. return b.String()
  44. }
  45. func format(query string, args ...any) (val string, err error) {
  46. defer func() {
  47. if err != nil {
  48. err = newAcceptableError(err)
  49. }
  50. }()
  51. numArgs := len(args)
  52. if numArgs == 0 {
  53. return query, nil
  54. }
  55. var b strings.Builder
  56. var argIndex int
  57. bytes := len(query)
  58. for i := 0; i < bytes; i++ {
  59. ch := query[i]
  60. switch ch {
  61. case '?':
  62. if argIndex >= numArgs {
  63. return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
  64. argIndex+1, numArgs)
  65. }
  66. writeValue(&b, args[argIndex])
  67. argIndex++
  68. case ':', '$':
  69. var j int
  70. for j = i + 1; j < bytes; j++ {
  71. char := query[j]
  72. if char < '0' || '9' < char {
  73. break
  74. }
  75. }
  76. if j > i+1 {
  77. index, err := strconv.Atoi(query[i+1 : j])
  78. if err != nil {
  79. return "", err
  80. }
  81. // index starts from 1 for pg or oracle
  82. if index > argIndex {
  83. argIndex = index
  84. }
  85. index--
  86. if index < 0 || numArgs <= index {
  87. return "", fmt.Errorf("wrong index %d in sql", index)
  88. }
  89. writeValue(&b, args[index])
  90. i = j - 1
  91. }
  92. case '\'', '"', '`':
  93. b.WriteByte(ch)
  94. for j := i + 1; j < bytes; j++ {
  95. cur := query[j]
  96. b.WriteByte(cur)
  97. if cur == '\\' {
  98. j++
  99. if j >= bytes {
  100. return "", errUnbalancedEscape
  101. }
  102. b.WriteByte(query[j])
  103. } else if cur == ch {
  104. i = j
  105. break
  106. }
  107. }
  108. default:
  109. b.WriteByte(ch)
  110. }
  111. }
  112. if argIndex < numArgs {
  113. return "", fmt.Errorf("%d arguments provided, not matching sql", argIndex)
  114. }
  115. return b.String(), nil
  116. }
  117. func logInstanceError(ctx context.Context, datasource string, err error) {
  118. datasource = desensitize(datasource)
  119. logx.WithContext(ctx).Errorf("Error on getting sql instance of %s: %v", datasource, err)
  120. }
  121. func logSqlError(ctx context.Context, stmt string, err error) {
  122. if err != nil && err != ErrNotFound {
  123. logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
  124. }
  125. }
  126. func writeValue(buf *strings.Builder, arg any) {
  127. switch v := arg.(type) {
  128. case bool:
  129. if v {
  130. buf.WriteByte('1')
  131. } else {
  132. buf.WriteByte('0')
  133. }
  134. case string:
  135. buf.WriteByte('\'')
  136. buf.WriteString(escape(v))
  137. buf.WriteByte('\'')
  138. case time.Time:
  139. buf.WriteByte('\'')
  140. buf.WriteString(v.String())
  141. buf.WriteByte('\'')
  142. case *time.Time:
  143. buf.WriteByte('\'')
  144. buf.WriteString(v.String())
  145. buf.WriteByte('\'')
  146. default:
  147. buf.WriteString(mapping.Repr(v))
  148. }
  149. }
  150. type acceptableError struct {
  151. err error
  152. }
  153. func newAcceptableError(err error) error {
  154. return acceptableError{
  155. err: err,
  156. }
  157. }
  158. func (e acceptableError) Error() string {
  159. return e.err.Error()
  160. }