gen_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package gen
  2. import (
  3. "database/sql"
  4. _ "embed"
  5. "os"
  6. "path"
  7. "path/filepath"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "github.com/zeromicro/go-zero/core/logx"
  14. "github.com/zeromicro/go-zero/core/stores/builder"
  15. "github.com/zeromicro/go-zero/core/stringx"
  16. "github.com/zeromicro/go-zero/tools/goctl/config"
  17. "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
  18. "github.com/zeromicro/go-zero/tools/goctl/util/pathx"
  19. )
  // source holds the DDL fixture testdata/user.sql, embedded at build time.
  // Tests write it to a .sql file and feed that file to the generator/parser.
  //go:embed testdata/user.sql
  var source string
  22. func TestCacheModel(t *testing.T) {
  23. logx.Disable()
  24. _ = Clean()
  25. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  26. err := os.WriteFile(sqlFile, []byte(source), 0o777)
  27. assert.Nil(t, err)
  28. dir := filepath.Join(pathx.MustTempDir(), "./testmodel")
  29. cacheDir := filepath.Join(dir, "cache")
  30. noCacheDir := filepath.Join(dir, "nocache")
  31. g, err := NewDefaultGenerator(cacheDir, &config.Config{
  32. NamingFormat: "GoZero",
  33. })
  34. assert.Nil(t, err)
  35. err = g.StartFromDDL(sqlFile, true, false, "go_zero")
  36. assert.Nil(t, err)
  37. assert.True(t, func() bool {
  38. _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
  39. return err == nil
  40. }())
  41. g, err = NewDefaultGenerator(noCacheDir, &config.Config{
  42. NamingFormat: "gozero",
  43. })
  44. assert.Nil(t, err)
  45. err = g.StartFromDDL(sqlFile, false, false, "go_zero")
  46. assert.Nil(t, err)
  47. assert.True(t, func() bool {
  48. _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
  49. return err == nil
  50. }())
  51. }
  52. func TestNamingModel(t *testing.T) {
  53. logx.Disable()
  54. _ = Clean()
  55. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  56. err := os.WriteFile(sqlFile, []byte(source), 0o777)
  57. assert.Nil(t, err)
  58. dir, _ := filepath.Abs("./testmodel")
  59. camelDir := filepath.Join(dir, "camel")
  60. snakeDir := filepath.Join(dir, "snake")
  61. defer func() {
  62. _ = os.RemoveAll(dir)
  63. }()
  64. g, err := NewDefaultGenerator(camelDir, &config.Config{
  65. NamingFormat: "GoZero",
  66. })
  67. assert.Nil(t, err)
  68. err = g.StartFromDDL(sqlFile, true, false, "go_zero")
  69. assert.Nil(t, err)
  70. assert.True(t, func() bool {
  71. _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
  72. return err == nil
  73. }())
  74. g, err = NewDefaultGenerator(snakeDir, &config.Config{
  75. NamingFormat: "go_zero",
  76. })
  77. assert.Nil(t, err)
  78. err = g.StartFromDDL(sqlFile, true, false, "go_zero")
  79. assert.Nil(t, err)
  80. assert.True(t, func() bool {
  81. _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
  82. return err == nil
  83. }())
  84. }
  85. func TestFolderName(t *testing.T) {
  86. logx.Disable()
  87. _ = Clean()
  88. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  89. err := os.WriteFile(sqlFile, []byte(source), 0o777)
  90. assert.Nil(t, err)
  91. dir, _ := filepath.Abs("./testmodel")
  92. camelDir := filepath.Join(dir, "go-camel")
  93. snakeDir := filepath.Join(dir, "go-snake")
  94. defer func() {
  95. _ = os.RemoveAll(dir)
  96. }()
  97. g, err := NewDefaultGenerator(camelDir, &config.Config{
  98. NamingFormat: "GoZero",
  99. })
  100. assert.Nil(t, err)
  101. pkg := g.pkg
  102. err = g.StartFromDDL(sqlFile, true, true, "go_zero")
  103. assert.Nil(t, err)
  104. assert.True(t, func() bool {
  105. _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
  106. return err == nil
  107. }())
  108. assert.Equal(t, pkg, g.pkg)
  109. g, err = NewDefaultGenerator(snakeDir, &config.Config{
  110. NamingFormat: "go_zero",
  111. })
  112. assert.Nil(t, err)
  113. err = g.StartFromDDL(sqlFile, true, true, "go_zero")
  114. assert.Nil(t, err)
  115. assert.True(t, func() bool {
  116. _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
  117. return err == nil
  118. }())
  119. }
  120. func TestWrapWithRawString(t *testing.T) {
  121. assert.Equal(t, "``", wrapWithRawString("", false))
  122. assert.Equal(t, "``", wrapWithRawString("``", false))
  123. assert.Equal(t, "`a`", wrapWithRawString("a", false))
  124. assert.Equal(t, "a", wrapWithRawString("a", true))
  125. assert.Equal(t, "` `", wrapWithRawString(" ", false))
  126. }
  127. func TestFields(t *testing.T) {
  128. type Student struct {
  129. ID int64 `db:"id"`
  130. Name string `db:"name"`
  131. Age sql.NullInt64 `db:"age"`
  132. Score sql.NullFloat64 `db:"score"`
  133. CreateTime time.Time `db:"create_time"`
  134. UpdateTime sql.NullTime `db:"update_time"`
  135. }
  136. var (
  137. studentFieldNames = builder.RawFieldNames(&Student{})
  138. studentRows = strings.Join(studentFieldNames, ",")
  139. studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
  140. studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
  141. )
  142. assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
  143. assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows)
  144. assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
  145. assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
  146. }
  147. func Test_genPublicModel(t *testing.T) {
  148. var err error
  149. dir := pathx.MustTempDir()
  150. modelDir := path.Join(dir, "model")
  151. err = os.MkdirAll(modelDir, 0o777)
  152. require.NoError(t, err)
  153. defer os.RemoveAll(dir)
  154. modelFilename := filepath.Join(modelDir, "foo.sql")
  155. err = os.WriteFile(modelFilename, []byte(source), 0o777)
  156. require.NoError(t, err)
  157. g, err := NewDefaultGenerator(modelDir, &config.Config{
  158. NamingFormat: config.DefaultFormat,
  159. })
  160. require.NoError(t, err)
  161. tables, err := parser.Parse(modelFilename, "", false)
  162. require.Equal(t, 1, len(tables))
  163. code, err := g.genModelCustom(*tables[0], false)
  164. assert.NoError(t, err)
  165. assert.True(t, strings.Contains(code, "package model"))
  166. assert.True(t, strings.Contains(code, "TestUserModel interface {\n\t\ttestUserModel\n\t}\n"))
  167. assert.True(t, strings.Contains(code, "customTestUserModel struct {\n\t\t*defaultTestUserModel\n\t}\n"))
  168. assert.True(t, strings.Contains(code, "func NewTestUserModel(conn sqlx.SqlConn) TestUserModel {"))
  169. }