parser_test.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package parser
  2. import (
  3. "io/ioutil"
  4. "path/filepath"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
  8. "github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
  9. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  10. )
  11. func TestParsePlainText(t *testing.T) {
  12. sqlFile := filepath.Join(ctlutil.MustTempDir(), "tmp.sql")
  13. err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
  14. assert.Nil(t, err)
  15. _, err = Parse(sqlFile, "go_zero")
  16. assert.NotNil(t, err)
  17. }
  18. func TestParseSelect(t *testing.T) {
  19. sqlFile := filepath.Join(ctlutil.MustTempDir(), "tmp.sql")
  20. err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
  21. assert.Nil(t, err)
  22. tables, err := Parse(sqlFile, "go_zero")
  23. assert.Nil(t, err)
  24. assert.Equal(t, 0, len(tables))
  25. }
  26. func TestParseCreateTable(t *testing.T) {
  27. sqlFile := filepath.Join(ctlutil.MustTempDir(), "tmp.sql")
  28. err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0o777)
  29. assert.Nil(t, err)
  30. tables, err := Parse(sqlFile, "go_zero")
  31. assert.Equal(t, 1, len(tables))
  32. table := tables[0]
  33. assert.Nil(t, err)
  34. assert.Equal(t, "test_user", table.Name.Source())
  35. assert.Equal(t, "id", table.PrimaryKey.Name.Source())
  36. assert.Equal(t, true, table.ContainsTime())
  37. assert.Equal(t, 2, len(table.UniqueIndex))
  38. assert.True(t, func() bool {
  39. for _, e := range table.Fields {
  40. if e.Comment != util.TrimNewLine(e.Comment) {
  41. return false
  42. }
  43. }
  44. return true
  45. }())
  46. }
  47. func TestConvertColumn(t *testing.T) {
  48. t.Run("missingPrimaryKey", func(t *testing.T) {
  49. columnData := model.ColumnData{
  50. Db: "user",
  51. Table: "user",
  52. Columns: []*model.Column{
  53. {
  54. DbColumn: &model.DbColumn{
  55. Name: "id",
  56. DataType: "bigint",
  57. },
  58. },
  59. },
  60. }
  61. _, err := columnData.Convert()
  62. assert.NotNil(t, err)
  63. assert.Contains(t, err.Error(), "missing primary key")
  64. })
  65. t.Run("jointPrimaryKey", func(t *testing.T) {
  66. columnData := model.ColumnData{
  67. Db: "user",
  68. Table: "user",
  69. Columns: []*model.Column{
  70. {
  71. DbColumn: &model.DbColumn{
  72. Name: "id",
  73. DataType: "bigint",
  74. },
  75. Index: &model.DbIndex{
  76. IndexName: "PRIMARY",
  77. },
  78. },
  79. {
  80. DbColumn: &model.DbColumn{
  81. Name: "mobile",
  82. DataType: "varchar",
  83. Comment: "手机号",
  84. },
  85. Index: &model.DbIndex{
  86. IndexName: "PRIMARY",
  87. },
  88. },
  89. },
  90. }
  91. _, err := columnData.Convert()
  92. assert.NotNil(t, err)
  93. assert.Contains(t, err.Error(), "joint primary key is not supported")
  94. })
  95. t.Run("normal", func(t *testing.T) {
  96. columnData := model.ColumnData{
  97. Db: "user",
  98. Table: "user",
  99. Columns: []*model.Column{
  100. {
  101. DbColumn: &model.DbColumn{
  102. Name: "id",
  103. DataType: "bigint",
  104. Extra: "auto_increment",
  105. },
  106. Index: &model.DbIndex{
  107. IndexName: "PRIMARY",
  108. SeqInIndex: 1,
  109. },
  110. },
  111. {
  112. DbColumn: &model.DbColumn{
  113. Name: "mobile",
  114. DataType: "varchar",
  115. Comment: "手机号",
  116. },
  117. Index: &model.DbIndex{
  118. IndexName: "mobile_unique",
  119. SeqInIndex: 1,
  120. },
  121. },
  122. },
  123. }
  124. table, err := columnData.Convert()
  125. assert.Nil(t, err)
  126. assert.True(t, table.PrimaryKey.Index.IndexName == "PRIMARY" && table.PrimaryKey.Name == "id")
  127. for _, item := range table.Columns {
  128. if item.Name == "mobile" {
  129. assert.True(t, item.Index.NonUnique == 0)
  130. break
  131. }
  132. }
  133. })
  134. }