123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385 |
- package parser
- import (
- "fmt"
- "path/filepath"
- "sort"
- "strings"
- "github.com/zeromicro/ddl-parser/parser"
- "github.com/zeromicro/go-zero/core/collection"
- "github.com/zeromicro/go-zero/tools/goctl/model/sql/converter"
- "github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
- "github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
- "github.com/zeromicro/go-zero/tools/goctl/util/console"
- "github.com/zeromicro/go-zero/tools/goctl/util/stringx"
- )
const (
	// timeImport is the Go type name that ContainsTime looks for in
	// Field.DataType.
	timeImport = "time.Time"
)
- type (
- // Table describes a mysql table
- Table struct {
- Name stringx.String
- Db stringx.String
- PrimaryKey Primary
- UniqueIndex map[string][]*Field
- Fields []*Field
- ContainsPQ bool
- }
- // Primary describes a primary key
- Primary struct {
- Field
- AutoIncrement bool
- }
- // Field describes a table field
- Field struct {
- NameOriginal string
- Name stringx.String
- DataType string
- Comment string
- SeqInIndex int
- OrdinalPosition int
- ContainsPQ bool
- }
- // KeyType types alias of int
- KeyType int
- )
- func parseNameOriginal(ts []*parser.Table) (nameOriginals [][]string) {
- var columns []string
- for _, t := range ts {
- columns = []string{}
- for _, c := range t.Columns {
- columns = append(columns, c.Name)
- }
- nameOriginals = append(nameOriginals, columns)
- }
- return
- }
- // Parse parses ddl into golang structure
- func Parse(filename, database string, strict bool) ([]*Table, error) {
- p := parser.NewParser()
- tables, err := p.From(filename)
- if err != nil {
- return nil, err
- }
- nameOriginals := parseNameOriginal(tables)
- indexNameGen := func(column ...string) string {
- return strings.Join(column, "_")
- }
- prefix := filepath.Base(filename)
- var list []*Table
- for indexTable, e := range tables {
- var (
- primaryColumn string
- primaryColumnSet = collection.NewSet()
- uniqueKeyMap = make(map[string][]string)
- // Unused local variable
- // normalKeyMap = make(map[string][]string)
- columns = e.Columns
- )
- for _, column := range columns {
- if column.Constraint != nil {
- if column.Constraint.Primary {
- primaryColumnSet.AddStr(column.Name)
- }
- if column.Constraint.Unique {
- indexName := indexNameGen(column.Name, "unique")
- uniqueKeyMap[indexName] = []string{column.Name}
- }
- if column.Constraint.Key {
- indexName := indexNameGen(column.Name, "idx")
- uniqueKeyMap[indexName] = []string{column.Name}
- }
- }
- }
- for _, e := range e.Constraints {
- if len(e.ColumnPrimaryKey) > 1 {
- return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
- }
- if len(e.ColumnPrimaryKey) == 1 {
- primaryColumn = e.ColumnPrimaryKey[0]
- primaryColumnSet.AddStr(e.ColumnPrimaryKey[0])
- }
- if len(e.ColumnUniqueKey) > 0 {
- list := append([]string(nil), e.ColumnUniqueKey...)
- list = append(list, "unique")
- indexName := indexNameGen(list...)
- uniqueKeyMap[indexName] = e.ColumnUniqueKey
- }
- }
- if primaryColumnSet.Count() > 1 {
- return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
- }
- primaryKey, fieldM, err := convertColumns(columns, primaryColumn, strict)
- if err != nil {
- return nil, err
- }
- var fields []*Field
- // sort
- for indexColumn, c := range columns {
- field, ok := fieldM[c.Name]
- if ok {
- field.NameOriginal = nameOriginals[indexTable][indexColumn]
- fields = append(fields, field)
- }
- }
- var (
- uniqueIndex = make(map[string][]*Field)
- // Unused local variable
- // normalIndex = make(map[string][]*Field)
- )
- for indexName, each := range uniqueKeyMap {
- for _, columnName := range each {
- // Prevent a crash if there is a unique key constraint with a nil field.
- if fieldM[columnName] == nil {
- return nil, fmt.Errorf("table %s: unique key with error column name[%s]", e.Name, columnName)
- }
- uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
- }
- }
- // Unused local variable
- // for indexName, each := range normalKeyMap {
- // for _, columnName := range each {
- // normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
- // }
- // }
- checkDuplicateUniqueIndex(uniqueIndex, e.Name)
- list = append(list, &Table{
- Name: stringx.From(e.Name),
- Db: stringx.From(database),
- PrimaryKey: primaryKey,
- UniqueIndex: uniqueIndex,
- Fields: fields,
- })
- }
- return list, nil
- }
- func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string) {
- log := console.NewColorConsole()
- uniqueSet := collection.NewSet()
- for k, i := range uniqueIndex {
- var list []string
- for _, e := range i {
- list = append(list, e.Name.Source())
- }
- joinRet := strings.Join(list, ",")
- if uniqueSet.Contains(joinRet) {
- log.Warning("[checkDuplicateUniqueIndex]: table %s: duplicate unique index %s", tableName, joinRet)
- delete(uniqueIndex, k)
- continue
- }
- uniqueSet.AddStr(joinRet)
- }
- }
- func convertColumns(columns []*parser.Column, primaryColumn string, strict bool) (Primary, map[string]*Field, error) {
- var (
- primaryKey Primary
- fieldM = make(map[string]*Field)
- log = console.NewColorConsole()
- )
- for _, column := range columns {
- if column == nil {
- continue
- }
- var (
- comment string
- isDefaultNull bool
- )
- if column.Constraint != nil {
- comment = column.Constraint.Comment
- isDefaultNull = !column.Constraint.NotNull
- if !column.Constraint.NotNull && column.Constraint.HasDefaultValue {
- isDefaultNull = false
- }
- if column.Name == primaryColumn {
- isDefaultNull = false
- }
- }
- dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
- if err != nil {
- return Primary{}, nil, err
- }
- if column.Constraint != nil {
- if column.Name == primaryColumn {
- if !column.Constraint.AutoIncrement && dataType == "int64" {
- log.Warning("[convertColumns]: The primary key %q is recommended to add constraint `AUTO_INCREMENT`", column.Name)
- }
- } else if column.Constraint.NotNull && !column.Constraint.HasDefaultValue {
- log.Warning("[convertColumns]: The column %q is recommended to add constraint `DEFAULT`", column.Name)
- }
- }
- var field Field
- field.Name = stringx.From(column.Name)
- field.DataType = dataType
- field.Comment = util.TrimNewLine(comment)
- if field.Name.Source() == primaryColumn {
- primaryKey = Primary{
- Field: field,
- }
- if column.Constraint != nil {
- primaryKey.AutoIncrement = column.Constraint.AutoIncrement
- }
- }
- fieldM[field.Name.Source()] = &field
- }
- return primaryKey, fieldM, nil
- }
- // ContainsTime returns true if contains golang type time.Time
- func (t *Table) ContainsTime() bool {
- for _, item := range t.Fields {
- if item.DataType == timeImport {
- return true
- }
- }
- return false
- }
- // ConvertDataType converts mysql data type into golang data type
- func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
- isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
- isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned")
- primaryDataType, containsPQ, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
- if err != nil {
- return nil, err
- }
- var reply Table
- reply.ContainsPQ = containsPQ
- reply.UniqueIndex = map[string][]*Field{}
- reply.Name = stringx.From(table.Table)
- reply.Db = stringx.From(table.Db)
- seqInIndex := 0
- if table.PrimaryKey.Index != nil {
- seqInIndex = table.PrimaryKey.Index.SeqInIndex
- }
- reply.PrimaryKey = Primary{
- Field: Field{
- Name: stringx.From(table.PrimaryKey.Name),
- DataType: primaryDataType,
- Comment: table.PrimaryKey.Comment,
- SeqInIndex: seqInIndex,
- OrdinalPosition: table.PrimaryKey.OrdinalPosition,
- },
- AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
- }
- fieldM, err := getTableFields(table, strict)
- if err != nil {
- return nil, err
- }
- for _, each := range fieldM {
- if each.ContainsPQ {
- reply.ContainsPQ = true
- }
- reply.Fields = append(reply.Fields, each)
- }
- sort.Slice(reply.Fields, func(i, j int) bool {
- return reply.Fields[i].OrdinalPosition < reply.Fields[j].OrdinalPosition
- })
- uniqueIndexSet := collection.NewSet()
- log := console.NewColorConsole()
- for indexName, each := range table.UniqueIndex {
- sort.Slice(each, func(i, j int) bool {
- if each[i].Index != nil {
- return each[i].Index.SeqInIndex < each[j].Index.SeqInIndex
- }
- return false
- })
- if len(each) == 1 {
- one := each[0]
- if one.Name == table.PrimaryKey.Name {
- log.Warning("[ConvertDataType]: table %q, duplicate unique index with primary key: %q", table.Table, one.Name)
- continue
- }
- }
- var list []*Field
- var uniqueJoin []string
- for _, c := range each {
- list = append(list, fieldM[c.Name])
- uniqueJoin = append(uniqueJoin, c.Name)
- }
- uniqueKey := strings.Join(uniqueJoin, ",")
- if uniqueIndexSet.Contains(uniqueKey) {
- log.Warning("[ConvertDataType]: table %q, duplicate unique index %q", table.Table, uniqueKey)
- continue
- }
- uniqueIndexSet.AddStr(uniqueKey)
- reply.UniqueIndex[indexName] = list
- }
- return &reply, nil
- }
- func getTableFields(table *model.Table, strict bool) (map[string]*Field, error) {
- fieldM := make(map[string]*Field)
- for _, each := range table.Columns {
- isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
- isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned")
- dt, containsPQ, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
- if err != nil {
- return nil, err
- }
- columnSeqInIndex := 0
- if each.Index != nil {
- columnSeqInIndex = each.Index.SeqInIndex
- }
- field := &Field{
- NameOriginal: each.Name,
- Name: stringx.From(each.Name),
- DataType: dt,
- Comment: each.Comment,
- SeqInIndex: columnSeqInIndex,
- OrdinalPosition: each.OrdinalPosition,
- ContainsPQ: containsPQ,
- }
- fieldM[each.Name] = field
- }
- return fieldM, nil
- }
|