utils.go 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
// Copied from core/stores/sqlx/utils.go.
  2. package mocksql
import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/mapping"
)
  10. // ErrNotFound is the alias of sql.ErrNoRows
  11. var ErrNotFound = sql.ErrNoRows
  12. func escape(input string) string {
  13. var b strings.Builder
  14. for _, ch := range input {
  15. switch ch {
  16. case '\x00':
  17. b.WriteString(`\x00`)
  18. case '\r':
  19. b.WriteString(`\r`)
  20. case '\n':
  21. b.WriteString(`\n`)
  22. case '\\':
  23. b.WriteString(`\\`)
  24. case '\'':
  25. b.WriteString(`\'`)
  26. case '"':
  27. b.WriteString(`\"`)
  28. case '\x1a':
  29. b.WriteString(`\x1a`)
  30. default:
  31. b.WriteRune(ch)
  32. }
  33. }
  34. return b.String()
  35. }
  36. func format(query string, args ...any) (string, error) {
  37. numArgs := len(args)
  38. if numArgs == 0 {
  39. return query, nil
  40. }
  41. var b strings.Builder
  42. argIndex := 0
  43. for _, ch := range query {
  44. if ch == '?' {
  45. if argIndex >= numArgs {
  46. return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
  47. }
  48. arg := args[argIndex]
  49. argIndex++
  50. switch v := arg.(type) {
  51. case bool:
  52. if v {
  53. b.WriteByte('1')
  54. } else {
  55. b.WriteByte('0')
  56. }
  57. case string:
  58. b.WriteByte('\'')
  59. b.WriteString(escape(v))
  60. b.WriteByte('\'')
  61. default:
  62. b.WriteString(mapping.Repr(v))
  63. }
  64. } else {
  65. b.WriteRune(ch)
  66. }
  67. }
  68. if argIndex < numArgs {
  69. return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
  70. }
  71. return b.String(), nil
  72. }
  73. func logSqlError(stmt string, err error) {
  74. if err != nil && err != ErrNotFound {
  75. logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
  76. }
  77. }