parser.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. package parser
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "sort"
  6. "strings"
  7. "github.com/tal-tech/go-zero/core/collection"
  8. "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
  9. "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
  10. "github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
  11. "github.com/tal-tech/go-zero/tools/goctl/util/console"
  12. "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
  13. "github.com/zeromicro/ddl-parser/parser"
  14. )
  15. const timeImport = "time.Time"
  16. type (
  17. // Table describes a mysql table
  18. Table struct {
  19. Name stringx.String
  20. PrimaryKey Primary
  21. UniqueIndex map[string][]*Field
  22. Fields []*Field
  23. }
  24. // Primary describes a primary key
  25. Primary struct {
  26. Field
  27. AutoIncrement bool
  28. }
  29. // Field describes a table field
  30. Field struct {
  31. Name stringx.String
  32. DataType string
  33. Comment string
  34. SeqInIndex int
  35. OrdinalPosition int
  36. }
  37. // KeyType types alias of int
  38. KeyType int
  39. )
  40. // Parse parses ddl into golang structure
  41. func Parse(filename string) ([]*Table, error) {
  42. p := parser.NewParser()
  43. tables, err := p.From(filename)
  44. if err != nil {
  45. return nil, err
  46. }
  47. indexNameGen := func(column ...string) string {
  48. return strings.Join(column, "_")
  49. }
  50. prefix := filepath.Base(filename)
  51. var list []*Table
  52. for _, e := range tables {
  53. columns := e.Columns
  54. var (
  55. primaryColumnSet = collection.NewSet()
  56. primaryColumn string
  57. uniqueKeyMap = make(map[string][]string)
  58. normalKeyMap = make(map[string][]string)
  59. )
  60. for _, column := range columns {
  61. if column.Constraint != nil {
  62. if column.Constraint.Primary {
  63. primaryColumnSet.AddStr(column.Name)
  64. }
  65. if column.Constraint.Unique {
  66. indexName := indexNameGen(column.Name, "unique")
  67. uniqueKeyMap[indexName] = []string{column.Name}
  68. }
  69. if column.Constraint.Key {
  70. indexName := indexNameGen(column.Name, "idx")
  71. uniqueKeyMap[indexName] = []string{column.Name}
  72. }
  73. }
  74. }
  75. for _, e := range e.Constraints {
  76. if len(e.ColumnPrimaryKey) > 1 {
  77. return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
  78. }
  79. if len(e.ColumnPrimaryKey) == 1 {
  80. primaryColumn = e.ColumnPrimaryKey[0]
  81. primaryColumnSet.AddStr(e.ColumnPrimaryKey[0])
  82. }
  83. if len(e.ColumnUniqueKey) > 0 {
  84. list := append([]string(nil), e.ColumnUniqueKey...)
  85. list = append(list, "unique")
  86. indexName := indexNameGen(list...)
  87. uniqueKeyMap[indexName] = e.ColumnUniqueKey
  88. }
  89. }
  90. if primaryColumnSet.Count() > 1 {
  91. return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
  92. }
  93. primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
  94. if err != nil {
  95. return nil, err
  96. }
  97. var fields []*Field
  98. // sort
  99. for _, c := range columns {
  100. field, ok := fieldM[c.Name]
  101. if ok {
  102. fields = append(fields, field)
  103. }
  104. }
  105. var (
  106. uniqueIndex = make(map[string][]*Field)
  107. normalIndex = make(map[string][]*Field)
  108. )
  109. for indexName, each := range uniqueKeyMap {
  110. for _, columnName := range each {
  111. uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
  112. }
  113. }
  114. for indexName, each := range normalKeyMap {
  115. for _, columnName := range each {
  116. normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
  117. }
  118. }
  119. checkDuplicateUniqueIndex(uniqueIndex, e.Name)
  120. list = append(list, &Table{
  121. Name: stringx.From(e.Name),
  122. PrimaryKey: primaryKey,
  123. UniqueIndex: uniqueIndex,
  124. Fields: fields,
  125. })
  126. }
  127. return list, nil
  128. }
  129. func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string) {
  130. log := console.NewColorConsole()
  131. uniqueSet := collection.NewSet()
  132. for k, i := range uniqueIndex {
  133. var list []string
  134. for _, e := range i {
  135. list = append(list, e.Name.Source())
  136. }
  137. joinRet := strings.Join(list, ",")
  138. if uniqueSet.Contains(joinRet) {
  139. log.Warning("table %s: duplicate unique index %s", tableName, joinRet)
  140. delete(uniqueIndex, k)
  141. continue
  142. }
  143. uniqueSet.AddStr(joinRet)
  144. }
  145. }
  146. func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
  147. var (
  148. primaryKey Primary
  149. fieldM = make(map[string]*Field)
  150. )
  151. for _, column := range columns {
  152. if column == nil {
  153. continue
  154. }
  155. var (
  156. comment string
  157. isDefaultNull bool
  158. )
  159. if column.Constraint != nil {
  160. comment = column.Constraint.Comment
  161. isDefaultNull = !column.Constraint.HasDefaultValue
  162. if column.Name == primaryColumn && column.Constraint.AutoIncrement {
  163. isDefaultNull = false
  164. }
  165. }
  166. dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull)
  167. if err != nil {
  168. return Primary{}, nil, err
  169. }
  170. var field Field
  171. field.Name = stringx.From(column.Name)
  172. field.DataType = dataType
  173. field.Comment = util.TrimNewLine(comment)
  174. if field.Name.Source() == primaryColumn {
  175. primaryKey = Primary{
  176. Field: field,
  177. }
  178. if column.Constraint != nil {
  179. primaryKey.AutoIncrement = column.Constraint.AutoIncrement
  180. }
  181. }
  182. fieldM[field.Name.Source()] = &field
  183. }
  184. return primaryKey, fieldM, nil
  185. }
  186. // ContainsTime returns true if contains golang type time.Time
  187. func (t *Table) ContainsTime() bool {
  188. for _, item := range t.Fields {
  189. if item.DataType == timeImport {
  190. return true
  191. }
  192. }
  193. return false
  194. }
  195. // ConvertDataType converts mysql data type into golang data type
  196. func ConvertDataType(table *model.Table) (*Table, error) {
  197. isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
  198. primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
  199. if err != nil {
  200. return nil, err
  201. }
  202. var reply Table
  203. reply.UniqueIndex = map[string][]*Field{}
  204. reply.Name = stringx.From(table.Table)
  205. seqInIndex := 0
  206. if table.PrimaryKey.Index != nil {
  207. seqInIndex = table.PrimaryKey.Index.SeqInIndex
  208. }
  209. reply.PrimaryKey = Primary{
  210. Field: Field{
  211. Name: stringx.From(table.PrimaryKey.Name),
  212. DataType: primaryDataType,
  213. Comment: table.PrimaryKey.Comment,
  214. SeqInIndex: seqInIndex,
  215. OrdinalPosition: table.PrimaryKey.OrdinalPosition,
  216. },
  217. AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
  218. }
  219. fieldM, err := getTableFields(table)
  220. if err != nil {
  221. return nil, err
  222. }
  223. for _, each := range fieldM {
  224. reply.Fields = append(reply.Fields, each)
  225. }
  226. sort.Slice(reply.Fields, func(i, j int) bool {
  227. return reply.Fields[i].OrdinalPosition < reply.Fields[j].OrdinalPosition
  228. })
  229. uniqueIndexSet := collection.NewSet()
  230. log := console.NewColorConsole()
  231. for indexName, each := range table.UniqueIndex {
  232. sort.Slice(each, func(i, j int) bool {
  233. if each[i].Index != nil {
  234. return each[i].Index.SeqInIndex < each[j].Index.SeqInIndex
  235. }
  236. return false
  237. })
  238. if len(each) == 1 {
  239. one := each[0]
  240. if one.Name == table.PrimaryKey.Name {
  241. log.Warning("table %s: duplicate unique index with primary key, %s", table.Table, one.Name)
  242. continue
  243. }
  244. }
  245. var list []*Field
  246. var uniqueJoin []string
  247. for _, c := range each {
  248. list = append(list, fieldM[c.Name])
  249. uniqueJoin = append(uniqueJoin, c.Name)
  250. }
  251. uniqueKey := strings.Join(uniqueJoin, ",")
  252. if uniqueIndexSet.Contains(uniqueKey) {
  253. log.Warning("table %s: duplicate unique index, %s", table.Table, uniqueKey)
  254. continue
  255. }
  256. uniqueIndexSet.AddStr(uniqueKey)
  257. reply.UniqueIndex[indexName] = list
  258. }
  259. return &reply, nil
  260. }
  261. func getTableFields(table *model.Table) (map[string]*Field, error) {
  262. fieldM := make(map[string]*Field)
  263. for _, each := range table.Columns {
  264. isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
  265. dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull)
  266. if err != nil {
  267. return nil, err
  268. }
  269. columnSeqInIndex := 0
  270. if each.Index != nil {
  271. columnSeqInIndex = each.Index.SeqInIndex
  272. }
  273. field := &Field{
  274. Name: stringx.From(each.Name),
  275. DataType: dt,
  276. Comment: each.Comment,
  277. SeqInIndex: columnSeqInIndex,
  278. OrdinalPosition: each.OrdinalPosition,
  279. }
  280. fieldM[each.Name] = field
  281. }
  282. return fieldM, nil
  283. }