utils.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package sqlx
import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/mapping"
)
  10. func desensitize(datasource string) string {
  11. // remove account
  12. pos := strings.LastIndex(datasource, "@")
  13. if 0 <= pos && pos+1 < len(datasource) {
  14. datasource = datasource[pos+1:]
  15. }
  16. return datasource
  17. }
  18. func escape(input string) string {
  19. var b strings.Builder
  20. for _, ch := range input {
  21. switch ch {
  22. case '\x00':
  23. b.WriteString(`\x00`)
  24. case '\r':
  25. b.WriteString(`\r`)
  26. case '\n':
  27. b.WriteString(`\n`)
  28. case '\\':
  29. b.WriteString(`\\`)
  30. case '\'':
  31. b.WriteString(`\'`)
  32. case '"':
  33. b.WriteString(`\"`)
  34. case '\x1a':
  35. b.WriteString(`\x1a`)
  36. default:
  37. b.WriteRune(ch)
  38. }
  39. }
  40. return b.String()
  41. }
  42. func format(query string, args ...interface{}) (string, error) {
  43. numArgs := len(args)
  44. if numArgs == 0 {
  45. return query, nil
  46. }
  47. var b strings.Builder
  48. var argIndex int
  49. bytes := len(query)
  50. for i := 0; i < bytes; i++ {
  51. ch := query[i]
  52. switch ch {
  53. case '?':
  54. if argIndex >= numArgs {
  55. return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
  56. }
  57. writeValue(&b, args[argIndex])
  58. argIndex++
  59. case '$':
  60. var j int
  61. for j = i + 1; j < bytes; j++ {
  62. char := query[j]
  63. if char < '0' || '9' < char {
  64. break
  65. }
  66. }
  67. if j > i+1 {
  68. index, err := strconv.Atoi(query[i+1 : j])
  69. if err != nil {
  70. return "", err
  71. }
  72. // index starts from 1 for pg
  73. if index > argIndex {
  74. argIndex = index
  75. }
  76. index--
  77. if index < 0 || numArgs <= index {
  78. return "", fmt.Errorf("error: wrong index %d in sql", index)
  79. }
  80. writeValue(&b, args[index])
  81. i = j - 1
  82. }
  83. default:
  84. b.WriteByte(ch)
  85. }
  86. }
  87. if argIndex < numArgs {
  88. return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
  89. }
  90. return b.String(), nil
  91. }
  92. func logInstanceError(datasource string, err error) {
  93. datasource = desensitize(datasource)
  94. logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
  95. }
  96. func logSqlError(ctx context.Context, stmt string, err error) {
  97. if err != nil && err != ErrNotFound {
  98. logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
  99. }
  100. }
  101. func writeValue(buf *strings.Builder, arg interface{}) {
  102. switch v := arg.(type) {
  103. case bool:
  104. if v {
  105. buf.WriteByte('1')
  106. } else {
  107. buf.WriteByte('0')
  108. }
  109. case string:
  110. buf.WriteByte('\'')
  111. buf.WriteString(escape(v))
  112. buf.WriteByte('\'')
  113. default:
  114. buf.WriteString(mapping.Repr(v))
  115. }
  116. }