utils.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package sqlx
  2. import (
  3. "fmt"
  4. "strings"
  5. "github.com/tal-tech/go-zero/core/logx"
  6. "github.com/tal-tech/go-zero/core/mapping"
  7. )
  // desensitize strips the account portion of a datasource string so it can be
  // logged: everything up to and including the last '@' is dropped.
  //
  // The input is returned unchanged when it contains no '@', or when the last
  // '@' is the final character (nothing would remain after it).
  func desensitize(datasource string) string {
  	at := strings.LastIndex(datasource, "@")
  	if at < 0 || at == len(datasource)-1 {
  		return datasource
  	}

  	return datasource[at+1:]
  }
  // escape returns a copy of input in which NUL, carriage return, line feed,
  // backslash, single quote, double quote and 0x1a are each replaced by a
  // backslash-prefixed textual sequence (`\x00`, `\r`, `\n`, `\\`, `\'`, `\"`,
  // `\x1a`). All other runes are copied through as-is.
  func escape(input string) string {
  	var out strings.Builder

  	for _, r := range input {
  		// repl stays empty for runes that need no escaping.
  		var repl string

  		switch r {
  		case '\x00':
  			repl = `\x00`
  		case '\r':
  			repl = `\r`
  		case '\n':
  			repl = `\n`
  		case '\\':
  			repl = `\\`
  		case '\'':
  			repl = `\'`
  		case '"':
  			repl = `\"`
  		case '\x1a':
  			repl = `\x1a`
  		}

  		if repl == "" {
  			out.WriteRune(r)
  			continue
  		}
  		out.WriteString(repl)
  	}

  	return out.String()
  }
  40. func formatForPrint(query string, args ...interface{}) string {
  41. if len(args) == 0 {
  42. return query
  43. }
  44. var vals []string
  45. for _, arg := range args {
  46. vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
  47. }
  48. var b strings.Builder
  49. b.WriteByte('[')
  50. b.WriteString(strings.Join(vals, ", "))
  51. b.WriteByte(']')
  52. return strings.Join([]string{query, b.String()}, " ")
  53. }
  54. func format(query string, args ...interface{}) (string, error) {
  55. numArgs := len(args)
  56. if numArgs == 0 {
  57. return query, nil
  58. }
  59. var b strings.Builder
  60. argIndex := 0
  61. for _, ch := range query {
  62. if ch == '?' {
  63. if argIndex >= numArgs {
  64. return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
  65. }
  66. arg := args[argIndex]
  67. argIndex++
  68. switch v := arg.(type) {
  69. case bool:
  70. if v {
  71. b.WriteByte('1')
  72. } else {
  73. b.WriteByte('0')
  74. }
  75. case string:
  76. b.WriteByte('\'')
  77. b.WriteString(escape(v))
  78. b.WriteByte('\'')
  79. default:
  80. b.WriteString(mapping.Repr(v))
  81. }
  82. } else {
  83. b.WriteRune(ch)
  84. }
  85. }
  86. if argIndex < numArgs {
  87. return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
  88. }
  89. return b.String(), nil
  90. }
  91. func logInstanceError(datasource string, err error) {
  92. datasource = desensitize(datasource)
  93. logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
  94. }
  95. func logSqlError(stmt string, err error) {
  96. if err != nil && err != ErrNotFound {
  97. logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
  98. }
  99. }