parser.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. package parser
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "sort"
  6. "strings"
  7. "github.com/tal-tech/go-zero/core/collection"
  8. "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
  9. "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
  10. "github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
  11. "github.com/tal-tech/go-zero/tools/goctl/util/console"
  12. "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
  13. "github.com/zeromicro/ddl-parser/parser"
  14. )
  15. const timeImport = "time.Time"
  16. type (
  17. // Table describes a mysql table
  18. Table struct {
  19. Name stringx.String
  20. Db stringx.String
  21. PrimaryKey Primary
  22. UniqueIndex map[string][]*Field
  23. Fields []*Field
  24. }
  25. // Primary describes a primary key
  26. Primary struct {
  27. Field
  28. AutoIncrement bool
  29. }
  30. // Field describes a table field
  31. Field struct {
  32. Name stringx.String
  33. DataType string
  34. Comment string
  35. SeqInIndex int
  36. OrdinalPosition int
  37. }
  38. // KeyType types alias of int
  39. KeyType int
  40. )
  41. // Parse parses ddl into golang structure
  42. func Parse(filename, database string) ([]*Table, error) {
  43. p := parser.NewParser()
  44. tables, err := p.From(filename)
  45. if err != nil {
  46. return nil, err
  47. }
  48. indexNameGen := func(column ...string) string {
  49. return strings.Join(column, "_")
  50. }
  51. prefix := filepath.Base(filename)
  52. var list []*Table
  53. for _, e := range tables {
  54. columns := e.Columns
  55. var (
  56. primaryColumnSet = collection.NewSet()
  57. primaryColumn string
  58. uniqueKeyMap = make(map[string][]string)
  59. normalKeyMap = make(map[string][]string)
  60. )
  61. for _, column := range columns {
  62. if column.Constraint != nil {
  63. if column.Constraint.Primary {
  64. primaryColumnSet.AddStr(column.Name)
  65. }
  66. if column.Constraint.Unique {
  67. indexName := indexNameGen(column.Name, "unique")
  68. uniqueKeyMap[indexName] = []string{column.Name}
  69. }
  70. if column.Constraint.Key {
  71. indexName := indexNameGen(column.Name, "idx")
  72. uniqueKeyMap[indexName] = []string{column.Name}
  73. }
  74. }
  75. }
  76. for _, e := range e.Constraints {
  77. if len(e.ColumnPrimaryKey) > 1 {
  78. return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
  79. }
  80. if len(e.ColumnPrimaryKey) == 1 {
  81. primaryColumn = e.ColumnPrimaryKey[0]
  82. primaryColumnSet.AddStr(e.ColumnPrimaryKey[0])
  83. }
  84. if len(e.ColumnUniqueKey) > 0 {
  85. list := append([]string(nil), e.ColumnUniqueKey...)
  86. list = append(list, "unique")
  87. indexName := indexNameGen(list...)
  88. uniqueKeyMap[indexName] = e.ColumnUniqueKey
  89. }
  90. }
  91. if primaryColumnSet.Count() > 1 {
  92. return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
  93. }
  94. primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
  95. if err != nil {
  96. return nil, err
  97. }
  98. var fields []*Field
  99. // sort
  100. for _, c := range columns {
  101. field, ok := fieldM[c.Name]
  102. if ok {
  103. fields = append(fields, field)
  104. }
  105. }
  106. var (
  107. uniqueIndex = make(map[string][]*Field)
  108. normalIndex = make(map[string][]*Field)
  109. )
  110. for indexName, each := range uniqueKeyMap {
  111. for _, columnName := range each {
  112. uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
  113. }
  114. }
  115. for indexName, each := range normalKeyMap {
  116. for _, columnName := range each {
  117. normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
  118. }
  119. }
  120. checkDuplicateUniqueIndex(uniqueIndex, e.Name)
  121. list = append(list, &Table{
  122. Name: stringx.From(e.Name),
  123. Db: stringx.From(database),
  124. PrimaryKey: primaryKey,
  125. UniqueIndex: uniqueIndex,
  126. Fields: fields,
  127. })
  128. }
  129. return list, nil
  130. }
  131. func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string) {
  132. log := console.NewColorConsole()
  133. uniqueSet := collection.NewSet()
  134. for k, i := range uniqueIndex {
  135. var list []string
  136. for _, e := range i {
  137. list = append(list, e.Name.Source())
  138. }
  139. joinRet := strings.Join(list, ",")
  140. if uniqueSet.Contains(joinRet) {
  141. log.Warning("table %s: duplicate unique index %s", tableName, joinRet)
  142. delete(uniqueIndex, k)
  143. continue
  144. }
  145. uniqueSet.AddStr(joinRet)
  146. }
  147. }
  148. func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
  149. var (
  150. primaryKey Primary
  151. fieldM = make(map[string]*Field)
  152. log = console.NewColorConsole()
  153. )
  154. for _, column := range columns {
  155. if column == nil {
  156. continue
  157. }
  158. var (
  159. comment string
  160. isDefaultNull bool
  161. )
  162. if column.Constraint != nil {
  163. comment = column.Constraint.Comment
  164. isDefaultNull = !column.Constraint.NotNull
  165. if !column.Constraint.NotNull && column.Constraint.HasDefaultValue {
  166. isDefaultNull = false
  167. }
  168. if column.Name == primaryColumn {
  169. isDefaultNull = false
  170. }
  171. }
  172. dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull)
  173. if err != nil {
  174. return Primary{}, nil, err
  175. }
  176. if column.Constraint != nil {
  177. if column.Name == primaryColumn {
  178. if !column.Constraint.AutoIncrement && dataType == "int64" {
  179. log.Warning("%s: The primary key is recommended to add constraint `AUTO_INCREMENT`", column.Name)
  180. }
  181. } else if column.Constraint.NotNull && !column.Constraint.HasDefaultValue {
  182. log.Warning("%s: The column is recommended to add constraint `DEFAULT`", column.Name)
  183. }
  184. }
  185. var field Field
  186. field.Name = stringx.From(column.Name)
  187. field.DataType = dataType
  188. field.Comment = util.TrimNewLine(comment)
  189. if field.Name.Source() == primaryColumn {
  190. primaryKey = Primary{
  191. Field: field,
  192. }
  193. if column.Constraint != nil {
  194. primaryKey.AutoIncrement = column.Constraint.AutoIncrement
  195. }
  196. }
  197. fieldM[field.Name.Source()] = &field
  198. }
  199. return primaryKey, fieldM, nil
  200. }
  201. // ContainsTime returns true if contains golang type time.Time
  202. func (t *Table) ContainsTime() bool {
  203. for _, item := range t.Fields {
  204. if item.DataType == timeImport {
  205. return true
  206. }
  207. }
  208. return false
  209. }
  210. // ConvertDataType converts mysql data type into golang data type
  211. func ConvertDataType(table *model.Table) (*Table, error) {
  212. isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
  213. primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
  214. if err != nil {
  215. return nil, err
  216. }
  217. var reply Table
  218. reply.UniqueIndex = map[string][]*Field{}
  219. reply.Name = stringx.From(table.Table)
  220. reply.Db = stringx.From(table.Db)
  221. seqInIndex := 0
  222. if table.PrimaryKey.Index != nil {
  223. seqInIndex = table.PrimaryKey.Index.SeqInIndex
  224. }
  225. reply.PrimaryKey = Primary{
  226. Field: Field{
  227. Name: stringx.From(table.PrimaryKey.Name),
  228. DataType: primaryDataType,
  229. Comment: table.PrimaryKey.Comment,
  230. SeqInIndex: seqInIndex,
  231. OrdinalPosition: table.PrimaryKey.OrdinalPosition,
  232. },
  233. AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
  234. }
  235. fieldM, err := getTableFields(table)
  236. if err != nil {
  237. return nil, err
  238. }
  239. for _, each := range fieldM {
  240. reply.Fields = append(reply.Fields, each)
  241. }
  242. sort.Slice(reply.Fields, func(i, j int) bool {
  243. return reply.Fields[i].OrdinalPosition < reply.Fields[j].OrdinalPosition
  244. })
  245. uniqueIndexSet := collection.NewSet()
  246. log := console.NewColorConsole()
  247. for indexName, each := range table.UniqueIndex {
  248. sort.Slice(each, func(i, j int) bool {
  249. if each[i].Index != nil {
  250. return each[i].Index.SeqInIndex < each[j].Index.SeqInIndex
  251. }
  252. return false
  253. })
  254. if len(each) == 1 {
  255. one := each[0]
  256. if one.Name == table.PrimaryKey.Name {
  257. log.Warning("table %s: duplicate unique index with primary key, %s", table.Table, one.Name)
  258. continue
  259. }
  260. }
  261. var list []*Field
  262. var uniqueJoin []string
  263. for _, c := range each {
  264. list = append(list, fieldM[c.Name])
  265. uniqueJoin = append(uniqueJoin, c.Name)
  266. }
  267. uniqueKey := strings.Join(uniqueJoin, ",")
  268. if uniqueIndexSet.Contains(uniqueKey) {
  269. log.Warning("table %s: duplicate unique index, %s", table.Table, uniqueKey)
  270. continue
  271. }
  272. uniqueIndexSet.AddStr(uniqueKey)
  273. reply.UniqueIndex[indexName] = list
  274. }
  275. return &reply, nil
  276. }
  277. func getTableFields(table *model.Table) (map[string]*Field, error) {
  278. fieldM := make(map[string]*Field)
  279. for _, each := range table.Columns {
  280. isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
  281. dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull)
  282. if err != nil {
  283. return nil, err
  284. }
  285. columnSeqInIndex := 0
  286. if each.Index != nil {
  287. columnSeqInIndex = each.Index.SeqInIndex
  288. }
  289. field := &Field{
  290. Name: stringx.From(each.Name),
  291. DataType: dt,
  292. Comment: each.Comment,
  293. SeqInIndex: columnSeqInIndex,
  294. OrdinalPosition: each.OrdinalPosition,
  295. }
  296. fieldM[each.Name] = field
  297. }
  298. return fieldM, nil
  299. }