gen_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. package gen
  2. import (
  3. "database/sql"
  4. _ "embed"
  5. "io/ioutil"
  6. "os"
  7. "path"
  8. "path/filepath"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. "github.com/zeromicro/go-zero/core/logx"
  15. "github.com/zeromicro/go-zero/core/stores/builder"
  16. "github.com/zeromicro/go-zero/core/stringx"
  17. "github.com/zeromicro/go-zero/tools/goctl/config"
  18. "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
  19. "github.com/zeromicro/go-zero/tools/goctl/util/pathx"
  20. )
// source holds the contents of testdata/user.sql, embedded at compile time.
// Tests in this file write it out to a temporary .sql file and pass that file
// to the model generator or the DDL parser.
//
//go:embed testdata/user.sql
var source string
  23. func TestCacheModel(t *testing.T) {
  24. logx.Disable()
  25. _ = Clean()
  26. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  27. err := ioutil.WriteFile(sqlFile, []byte(source), 0o777)
  28. assert.Nil(t, err)
  29. dir := filepath.Join(pathx.MustTempDir(), "./testmodel")
  30. cacheDir := filepath.Join(dir, "cache")
  31. noCacheDir := filepath.Join(dir, "nocache")
  32. g, err := NewDefaultGenerator(cacheDir, &config.Config{
  33. NamingFormat: "GoZero",
  34. })
  35. assert.Nil(t, err)
  36. err = g.StartFromDDL(sqlFile, true, false, "go_zero")
  37. assert.Nil(t, err)
  38. assert.True(t, func() bool {
  39. _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
  40. return err == nil
  41. }())
  42. g, err = NewDefaultGenerator(noCacheDir, &config.Config{
  43. NamingFormat: "gozero",
  44. })
  45. assert.Nil(t, err)
  46. err = g.StartFromDDL(sqlFile, false, false, "go_zero")
  47. assert.Nil(t, err)
  48. assert.True(t, func() bool {
  49. _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
  50. return err == nil
  51. }())
  52. }
  53. func TestNamingModel(t *testing.T) {
  54. logx.Disable()
  55. _ = Clean()
  56. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  57. err := ioutil.WriteFile(sqlFile, []byte(source), 0o777)
  58. assert.Nil(t, err)
  59. dir, _ := filepath.Abs("./testmodel")
  60. camelDir := filepath.Join(dir, "camel")
  61. snakeDir := filepath.Join(dir, "snake")
  62. defer func() {
  63. _ = os.RemoveAll(dir)
  64. }()
  65. g, err := NewDefaultGenerator(camelDir, &config.Config{
  66. NamingFormat: "GoZero",
  67. })
  68. assert.Nil(t, err)
  69. err = g.StartFromDDL(sqlFile, true, false, "go_zero")
  70. assert.Nil(t, err)
  71. assert.True(t, func() bool {
  72. _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
  73. return err == nil
  74. }())
  75. g, err = NewDefaultGenerator(snakeDir, &config.Config{
  76. NamingFormat: "go_zero",
  77. })
  78. assert.Nil(t, err)
  79. err = g.StartFromDDL(sqlFile, true, false, "go_zero")
  80. assert.Nil(t, err)
  81. assert.True(t, func() bool {
  82. _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
  83. return err == nil
  84. }())
  85. }
  86. func TestFolderName(t *testing.T) {
  87. logx.Disable()
  88. _ = Clean()
  89. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  90. err := ioutil.WriteFile(sqlFile, []byte(source), 0o777)
  91. assert.Nil(t, err)
  92. dir, _ := filepath.Abs("./testmodel")
  93. camelDir := filepath.Join(dir, "go-camel")
  94. snakeDir := filepath.Join(dir, "go-snake")
  95. defer func() {
  96. _ = os.RemoveAll(dir)
  97. }()
  98. g, err := NewDefaultGenerator(camelDir, &config.Config{
  99. NamingFormat: "GoZero",
  100. })
  101. assert.Nil(t, err)
  102. pkg := g.pkg
  103. err = g.StartFromDDL(sqlFile, true, "go_zero")
  104. assert.Nil(t, err)
  105. assert.True(t, func() bool {
  106. _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
  107. return err == nil
  108. }())
  109. assert.Equal(t, pkg, g.pkg)
  110. g, err = NewDefaultGenerator(snakeDir, &config.Config{
  111. NamingFormat: "go_zero",
  112. })
  113. assert.Nil(t, err)
  114. err = g.StartFromDDL(sqlFile, true, "go_zero")
  115. assert.Nil(t, err)
  116. assert.True(t, func() bool {
  117. _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
  118. return err == nil
  119. }())
  120. }
  121. func TestWrapWithRawString(t *testing.T) {
  122. assert.Equal(t, "``", wrapWithRawString("", false))
  123. assert.Equal(t, "``", wrapWithRawString("``", false))
  124. assert.Equal(t, "`a`", wrapWithRawString("a", false))
  125. assert.Equal(t, "a", wrapWithRawString("a", true))
  126. assert.Equal(t, "` `", wrapWithRawString(" ", false))
  127. }
  128. func TestFields(t *testing.T) {
  129. type Student struct {
  130. ID int64 `db:"id"`
  131. Name string `db:"name"`
  132. Age sql.NullInt64 `db:"age"`
  133. Score sql.NullFloat64 `db:"score"`
  134. CreateTime time.Time `db:"create_time"`
  135. UpdateTime sql.NullTime `db:"update_time"`
  136. }
  137. var (
  138. studentFieldNames = builder.RawFieldNames(&Student{})
  139. studentRows = strings.Join(studentFieldNames, ",")
  140. studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
  141. studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
  142. )
  143. assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
  144. assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows)
  145. assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
  146. assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
  147. }
  148. func Test_genPublicModel(t *testing.T) {
  149. var err error
  150. dir := pathx.MustTempDir()
  151. modelDir := path.Join(dir, "model")
  152. err = os.MkdirAll(modelDir, 0o777)
  153. require.NoError(t, err)
  154. defer os.RemoveAll(dir)
  155. modelFilename := filepath.Join(modelDir, "foo.sql")
  156. err = ioutil.WriteFile(modelFilename, []byte(source), 0o777)
  157. require.NoError(t, err)
  158. g, err := NewDefaultGenerator(modelDir, &config.Config{
  159. NamingFormat: config.DefaultFormat,
  160. })
  161. require.NoError(t, err)
  162. tables, err := parser.Parse(modelFilename, "", false)
  163. require.Equal(t, 1, len(tables))
  164. code, err := g.genModelCustom(*tables[0], false)
  165. assert.NoError(t, err)
  166. assert.True(t, strings.Contains(code, "package model"))
  167. assert.True(t, strings.Contains(code, "TestUserModel interface {\n\t\ttestUserModel\n\t}\n"))
  168. assert.True(t, strings.Contains(code, "customTestUserModel struct {\n\t\t*defaultTestUserModel\n\t}\n"))
  169. assert.True(t, strings.Contains(code, "func NewTestUserModel(conn sqlx.SqlConn) TestUserModel {"))
  170. }