parser_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package parser
  2. import (
  3. "io/ioutil"
  4. "path/filepath"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
  8. "github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
  9. )
  10. func TestParsePlainText(t *testing.T) {
  11. sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
  12. err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
  13. assert.Nil(t, err)
  14. _, err = Parse(sqlFile)
  15. assert.NotNil(t, err)
  16. }
  17. func TestParseSelect(t *testing.T) {
  18. sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
  19. err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
  20. assert.Nil(t, err)
  21. tables, err := Parse(sqlFile)
  22. assert.Nil(t, err)
  23. assert.Equal(t, 0, len(tables))
  24. }
  25. func TestParseCreateTable(t *testing.T) {
  26. sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
  27. err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0o777)
  28. assert.Nil(t, err)
  29. tables, err := Parse(sqlFile)
  30. assert.Equal(t, 1, len(tables))
  31. table := tables[0]
  32. assert.Nil(t, err)
  33. assert.Equal(t, "test_user", table.Name.Source())
  34. assert.Equal(t, "id", table.PrimaryKey.Name.Source())
  35. assert.Equal(t, true, table.ContainsTime())
  36. assert.Equal(t, 2, len(table.UniqueIndex))
  37. assert.True(t, func() bool {
  38. for _, e := range table.Fields {
  39. if e.Comment != util.TrimNewLine(e.Comment) {
  40. return false
  41. }
  42. }
  43. return true
  44. }())
  45. }
  46. func TestConvertColumn(t *testing.T) {
  47. t.Run("missingPrimaryKey", func(t *testing.T) {
  48. columnData := model.ColumnData{
  49. Db: "user",
  50. Table: "user",
  51. Columns: []*model.Column{
  52. {
  53. DbColumn: &model.DbColumn{
  54. Name: "id",
  55. DataType: "bigint",
  56. },
  57. },
  58. },
  59. }
  60. _, err := columnData.Convert()
  61. assert.NotNil(t, err)
  62. assert.Contains(t, err.Error(), "missing primary key")
  63. })
  64. t.Run("jointPrimaryKey", func(t *testing.T) {
  65. columnData := model.ColumnData{
  66. Db: "user",
  67. Table: "user",
  68. Columns: []*model.Column{
  69. {
  70. DbColumn: &model.DbColumn{
  71. Name: "id",
  72. DataType: "bigint",
  73. },
  74. Index: &model.DbIndex{
  75. IndexName: "PRIMARY",
  76. },
  77. },
  78. {
  79. DbColumn: &model.DbColumn{
  80. Name: "mobile",
  81. DataType: "varchar",
  82. Comment: "手机号",
  83. },
  84. Index: &model.DbIndex{
  85. IndexName: "PRIMARY",
  86. },
  87. },
  88. },
  89. }
  90. _, err := columnData.Convert()
  91. assert.NotNil(t, err)
  92. assert.Contains(t, err.Error(), "joint primary key is not supported")
  93. })
  94. t.Run("normal", func(t *testing.T) {
  95. columnData := model.ColumnData{
  96. Db: "user",
  97. Table: "user",
  98. Columns: []*model.Column{
  99. {
  100. DbColumn: &model.DbColumn{
  101. Name: "id",
  102. DataType: "bigint",
  103. Extra: "auto_increment",
  104. },
  105. Index: &model.DbIndex{
  106. IndexName: "PRIMARY",
  107. SeqInIndex: 1,
  108. },
  109. },
  110. {
  111. DbColumn: &model.DbColumn{
  112. Name: "mobile",
  113. DataType: "varchar",
  114. Comment: "手机号",
  115. },
  116. Index: &model.DbIndex{
  117. IndexName: "mobile_unique",
  118. SeqInIndex: 1,
  119. },
  120. },
  121. },
  122. }
  123. table, err := columnData.Convert()
  124. assert.Nil(t, err)
  125. assert.True(t, table.PrimaryKey.Index.IndexName == "PRIMARY" && table.PrimaryKey.Name == "id")
  126. for _, item := range table.Columns {
  127. if item.Name == "mobile" {
  128. assert.True(t, item.Index.NonUnique == 0)
  129. break
  130. }
  131. }
  132. })
  133. }