parser.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package parser
  2. import (
  3. "fmt"
  4. "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
  5. "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
  6. "github.com/xwb1989/sqlparser"
  7. )
  8. const (
  9. none = iota
  10. primary
  11. unique
  12. normal
  13. spatial
  14. )
  15. const timeImport = "time.Time"
  16. type (
  17. Table struct {
  18. Name stringx.String
  19. PrimaryKey Primary
  20. Fields []Field
  21. }
  22. Primary struct {
  23. Field
  24. AutoIncrement bool
  25. }
  26. Field struct {
  27. Name stringx.String
  28. DataBaseType string
  29. DataType string
  30. IsKey bool
  31. IsPrimaryKey bool
  32. Comment string
  33. }
  34. KeyType int
  35. )
  36. func Parse(ddl string) (*Table, error) {
  37. stmt, err := sqlparser.ParseStrictDDL(ddl)
  38. if err != nil {
  39. return nil, err
  40. }
  41. ddlStmt, ok := stmt.(*sqlparser.DDL)
  42. if !ok {
  43. return nil, unSupportDDL
  44. }
  45. action := ddlStmt.Action
  46. if action != sqlparser.CreateStr {
  47. return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action)
  48. }
  49. tableName := ddlStmt.NewName.Name.String()
  50. tableSpec := ddlStmt.TableSpec
  51. if tableSpec == nil {
  52. return nil, tableBodyIsNotFound
  53. }
  54. columns := tableSpec.Columns
  55. indexes := tableSpec.Indexes
  56. keyMap := make(map[string]KeyType)
  57. for _, index := range indexes {
  58. info := index.Info
  59. if info == nil {
  60. continue
  61. }
  62. if info.Primary {
  63. if len(index.Columns) > 1 {
  64. return nil, errPrimaryKey
  65. }
  66. keyMap[index.Columns[0].Column.String()] = primary
  67. continue
  68. }
  69. // can optimize
  70. if len(index.Columns) > 1 {
  71. continue
  72. }
  73. column := index.Columns[0]
  74. columnName := column.Column.String()
  75. camelColumnName := stringx.From(columnName).ToCamel()
  76. // by default, createTime|updateTime findOne is not used.
  77. if camelColumnName == "CreateTime" || camelColumnName == "UpdateTime" {
  78. continue
  79. }
  80. if info.Unique {
  81. keyMap[columnName] = unique
  82. } else if info.Spatial {
  83. keyMap[columnName] = spatial
  84. } else {
  85. keyMap[columnName] = normal
  86. }
  87. }
  88. var fields []Field
  89. var primaryKey Primary
  90. for _, column := range columns {
  91. if column == nil {
  92. continue
  93. }
  94. var comment string
  95. if column.Type.Comment != nil {
  96. comment = string(column.Type.Comment.Val)
  97. }
  98. dataType, err := converter.ConvertDataType(column.Type.Type)
  99. if err != nil {
  100. return nil, err
  101. }
  102. var field Field
  103. field.Name = stringx.From(column.Name.String())
  104. field.DataBaseType = column.Type.Type
  105. field.DataType = dataType
  106. field.Comment = comment
  107. key, ok := keyMap[column.Name.String()]
  108. if ok {
  109. field.IsKey = true
  110. field.IsPrimaryKey = key == primary
  111. if field.IsPrimaryKey {
  112. primaryKey.Field = field
  113. if column.Type.Autoincrement {
  114. primaryKey.AutoIncrement = true
  115. }
  116. }
  117. }
  118. fields = append(fields, field)
  119. }
  120. return &Table{
  121. Name: stringx.From(tableName),
  122. PrimaryKey: primaryKey,
  123. Fields: fields,
  124. }, nil
  125. }
  126. func (t *Table) ContainsTime() bool {
  127. for _, item := range t.Fields {
  128. if item.DataType == timeImport {
  129. return true
  130. }
  131. }
  132. return false
  133. }