gen_test.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package gen
  2. import (
  3. "database/sql"
  4. "io/ioutil"
  5. "os"
  6. "path"
  7. "path/filepath"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "github.com/zeromicro/go-zero/core/logx"
  14. "github.com/zeromicro/go-zero/core/stringx"
  15. "github.com/zeromicro/go-zero/tools/goctl/config"
  16. "github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx"
  17. "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
  18. "github.com/zeromicro/go-zero/tools/goctl/util/pathx"
  19. )
  20. var source = "CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL,\n `class` bigint NOT NULL,\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"
  21. func TestCacheModel(t *testing.T) {
  22. logx.Disable()
  23. _ = Clean()
  24. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  25. err := ioutil.WriteFile(sqlFile, []byte(source), 0o777)
  26. assert.Nil(t, err)
  27. dir := filepath.Join(pathx.MustTempDir(), "./testmodel")
  28. cacheDir := filepath.Join(dir, "cache")
  29. noCacheDir := filepath.Join(dir, "nocache")
  30. g, err := NewDefaultGenerator(cacheDir, &config.Config{
  31. NamingFormat: "GoZero",
  32. })
  33. assert.Nil(t, err)
  34. err = g.StartFromDDL(sqlFile, true, "go_zero")
  35. assert.Nil(t, err)
  36. assert.True(t, func() bool {
  37. _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
  38. return err == nil
  39. }())
  40. g, err = NewDefaultGenerator(noCacheDir, &config.Config{
  41. NamingFormat: "gozero",
  42. })
  43. assert.Nil(t, err)
  44. err = g.StartFromDDL(sqlFile, false, "go_zero")
  45. assert.Nil(t, err)
  46. assert.True(t, func() bool {
  47. _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
  48. return err == nil
  49. }())
  50. }
  51. func TestNamingModel(t *testing.T) {
  52. logx.Disable()
  53. _ = Clean()
  54. sqlFile := filepath.Join(pathx.MustTempDir(), "tmp.sql")
  55. err := ioutil.WriteFile(sqlFile, []byte(source), 0o777)
  56. assert.Nil(t, err)
  57. dir, _ := filepath.Abs("./testmodel")
  58. camelDir := filepath.Join(dir, "camel")
  59. snakeDir := filepath.Join(dir, "snake")
  60. defer func() {
  61. _ = os.RemoveAll(dir)
  62. }()
  63. g, err := NewDefaultGenerator(camelDir, &config.Config{
  64. NamingFormat: "GoZero",
  65. })
  66. assert.Nil(t, err)
  67. err = g.StartFromDDL(sqlFile, true, "go_zero")
  68. assert.Nil(t, err)
  69. assert.True(t, func() bool {
  70. _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
  71. return err == nil
  72. }())
  73. g, err = NewDefaultGenerator(snakeDir, &config.Config{
  74. NamingFormat: "go_zero",
  75. })
  76. assert.Nil(t, err)
  77. err = g.StartFromDDL(sqlFile, true, "go_zero")
  78. assert.Nil(t, err)
  79. assert.True(t, func() bool {
  80. _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
  81. return err == nil
  82. }())
  83. }
  84. func TestWrapWithRawString(t *testing.T) {
  85. assert.Equal(t, "``", wrapWithRawString("", false))
  86. assert.Equal(t, "``", wrapWithRawString("``", false))
  87. assert.Equal(t, "`a`", wrapWithRawString("a", false))
  88. assert.Equal(t, "a", wrapWithRawString("a", true))
  89. assert.Equal(t, "` `", wrapWithRawString(" ", false))
  90. }
  91. func TestFields(t *testing.T) {
  92. type Student struct {
  93. ID int64 `db:"id"`
  94. Name string `db:"name"`
  95. Age sql.NullInt64 `db:"age"`
  96. Score sql.NullFloat64 `db:"score"`
  97. CreateTime time.Time `db:"create_time"`
  98. UpdateTime sql.NullTime `db:"update_time"`
  99. }
  100. var (
  101. studentFieldNames = builderx.RawFieldNames(&Student{})
  102. studentRows = strings.Join(studentFieldNames, ",")
  103. studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
  104. studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
  105. )
  106. assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
  107. assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows)
  108. assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
  109. assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
  110. }
  111. func Test_genPublicModel(t *testing.T) {
  112. var err error
  113. dir := pathx.MustTempDir()
  114. modelDir := path.Join(dir, "model")
  115. err = os.MkdirAll(modelDir, 0777)
  116. require.NoError(t, err)
  117. defer os.RemoveAll(dir)
  118. modelFilename := filepath.Join(modelDir, "foo.sql")
  119. err = ioutil.WriteFile(modelFilename, []byte(source), 0777)
  120. require.NoError(t, err)
  121. g, err := NewDefaultGenerator(modelDir, &config.Config{
  122. NamingFormat: config.DefaultFormat,
  123. })
  124. require.NoError(t, err)
  125. tables, err := parser.Parse(modelFilename, "")
  126. require.Equal(t, 1, len(tables))
  127. code, err := g.genModelCustom(*tables[0], false)
  128. assert.NoError(t, err)
  129. assert.True(t, strings.Contains(code, "package model"))
  130. assert.True(t, strings.Contains(code, "TestUserModel interface {\n\t\ttestUserModel\n\t}\n"))
  131. assert.True(t, strings.Contains(code, "customTestUserModel struct {\n\t\t*defaultTestUserModel\n\t}\n"))
  132. assert.True(t, strings.Contains(code, "func NewTestUserModel(conn sqlx.SqlConn) TestUserModel {"))
  133. }