|
@@ -2,6 +2,7 @@ package parser
|
|
|
|
|
|
import (
|
|
import (
|
|
"fmt"
|
|
"fmt"
|
|
|
|
+ "path/filepath"
|
|
"sort"
|
|
"sort"
|
|
"strings"
|
|
"strings"
|
|
|
|
|
|
@@ -11,7 +12,7 @@ import (
|
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
|
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
|
|
"github.com/tal-tech/go-zero/tools/goctl/util/console"
|
|
"github.com/tal-tech/go-zero/tools/goctl/util/console"
|
|
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
|
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
|
- "github.com/xwb1989/sqlparser"
|
|
|
|
|
|
+ "github.com/zeromicro/ddl-parser/parser"
|
|
)
|
|
)
|
|
|
|
|
|
const timeImport = "time.Time"
|
|
const timeImport = "time.Time"
|
|
@@ -22,7 +23,6 @@ type (
|
|
Name stringx.String
|
|
Name stringx.String
|
|
PrimaryKey Primary
|
|
PrimaryKey Primary
|
|
UniqueIndex map[string][]*Field
|
|
UniqueIndex map[string][]*Field
|
|
- NormalIndex map[string][]*Field
|
|
|
|
Fields []*Field
|
|
Fields []*Field
|
|
}
|
|
}
|
|
|
|
|
|
@@ -35,7 +35,6 @@ type (
|
|
// Field describes a table field
|
|
// Field describes a table field
|
|
Field struct {
|
|
Field struct {
|
|
Name stringx.String
|
|
Name stringx.String
|
|
- DataBaseType string
|
|
|
|
DataType string
|
|
DataType string
|
|
Comment string
|
|
Comment string
|
|
SeqInIndex int
|
|
SeqInIndex int
|
|
@@ -47,73 +46,115 @@ type (
|
|
)
|
|
)
|
|
|
|
|
|
// Parse parses ddl into golang structure
|
|
// Parse parses ddl into golang structure
|
|
-func Parse(ddl string) (*Table, error) {
|
|
|
|
- stmt, err := sqlparser.ParseStrictDDL(ddl)
|
|
|
|
|
|
+func Parse(filename string) ([]*Table, error) {
|
|
|
|
+ p := parser.NewParser()
|
|
|
|
+ tables, err := p.From(filename)
|
|
if err != nil {
|
|
if err != nil {
|
|
return nil, err
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
|
|
- ddlStmt, ok := stmt.(*sqlparser.DDL)
|
|
|
|
- if !ok {
|
|
|
|
- return nil, errUnsupportDDL
|
|
|
|
|
|
+ indexNameGen := func(column ...string) string {
|
|
|
|
+ return strings.Join(column, "_")
|
|
}
|
|
}
|
|
|
|
|
|
- action := ddlStmt.Action
|
|
|
|
- if action != sqlparser.CreateStr {
|
|
|
|
- return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action)
|
|
|
|
- }
|
|
|
|
|
|
+ prefix := filepath.Base(filename)
|
|
|
|
+ var list []*Table
|
|
|
|
+ for _, e := range tables {
|
|
|
|
+ columns := e.Columns
|
|
|
|
|
|
- tableName := ddlStmt.NewName.Name.String()
|
|
|
|
- tableSpec := ddlStmt.TableSpec
|
|
|
|
- if tableSpec == nil {
|
|
|
|
- return nil, errTableBodyNotFound
|
|
|
|
- }
|
|
|
|
|
|
+ var (
|
|
|
|
+ primaryColumnSet = collection.NewSet()
|
|
|
|
|
|
- columns := tableSpec.Columns
|
|
|
|
- indexes := tableSpec.Indexes
|
|
|
|
- primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, err
|
|
|
|
- }
|
|
|
|
|
|
+ primaryColumn string
|
|
|
|
+ uniqueKeyMap = make(map[string][]string)
|
|
|
|
+ normalKeyMap = make(map[string][]string)
|
|
|
|
+ )
|
|
|
|
|
|
- primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, err
|
|
|
|
- }
|
|
|
|
|
|
+ for _, column := range columns {
|
|
|
|
+ if column.Constraint != nil {
|
|
|
|
+ if column.Constraint.Primary {
|
|
|
|
+ primaryColumnSet.AddStr(column.Name)
|
|
|
|
+ }
|
|
|
|
|
|
- var fields []*Field
|
|
|
|
- for _, e := range fieldM {
|
|
|
|
- fields = append(fields, e)
|
|
|
|
- }
|
|
|
|
|
|
+ if column.Constraint.Unique {
|
|
|
|
+ indexName := indexNameGen(column.Name, "unique")
|
|
|
|
+ uniqueKeyMap[indexName] = []string{column.Name}
|
|
|
|
+ }
|
|
|
|
|
|
- var (
|
|
|
|
- uniqueIndex = make(map[string][]*Field)
|
|
|
|
- normalIndex = make(map[string][]*Field)
|
|
|
|
- )
|
|
|
|
|
|
+ if column.Constraint.Key {
|
|
|
|
+ indexName := indexNameGen(column.Name, "idx")
|
|
|
|
+ uniqueKeyMap[indexName] = []string{column.Name}
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
|
|
- for indexName, each := range uniqueKeyMap {
|
|
|
|
- for _, columnName := range each {
|
|
|
|
- uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
|
|
|
|
|
|
+ for _, e := range e.Constraints {
|
|
|
|
+ if len(e.ColumnPrimaryKey) > 1 {
|
|
|
|
+ return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if len(e.ColumnPrimaryKey) == 1 {
|
|
|
|
+ primaryColumn = e.ColumnPrimaryKey[0]
|
|
|
|
+ primaryColumnSet.AddStr(e.ColumnPrimaryKey[0])
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if len(e.ColumnUniqueKey) > 0 {
|
|
|
|
+ list := append([]string(nil), e.ColumnUniqueKey...)
|
|
|
|
+ list = append(list, "unique")
|
|
|
|
+ indexName := indexNameGen(list...)
|
|
|
|
+ uniqueKeyMap[indexName] = e.ColumnUniqueKey
|
|
|
|
+ }
|
|
}
|
|
}
|
|
- }
|
|
|
|
|
|
|
|
- for indexName, each := range normalKeyMap {
|
|
|
|
- for _, columnName := range each {
|
|
|
|
- normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
|
|
|
|
|
|
+ if primaryColumnSet.Count() > 1 {
|
|
|
|
+ return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var fields []*Field
|
|
|
|
+ // sort
|
|
|
|
+ for _, c := range columns {
|
|
|
|
+ field, ok := fieldM[c.Name]
|
|
|
|
+ if ok {
|
|
|
|
+ fields = append(fields, field)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var (
|
|
|
|
+ uniqueIndex = make(map[string][]*Field)
|
|
|
|
+ normalIndex = make(map[string][]*Field)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ for indexName, each := range uniqueKeyMap {
|
|
|
|
+ for _, columnName := range each {
|
|
|
|
+ uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for indexName, each := range normalKeyMap {
|
|
|
|
+ for _, columnName := range each {
|
|
|
|
+ normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ checkDuplicateUniqueIndex(uniqueIndex, e.Name)
|
|
|
|
+
|
|
|
|
+ list = append(list, &Table{
|
|
|
|
+ Name: stringx.From(e.Name),
|
|
|
|
+ PrimaryKey: primaryKey,
|
|
|
|
+ UniqueIndex: uniqueIndex,
|
|
|
|
+ Fields: fields,
|
|
|
|
+ })
|
|
}
|
|
}
|
|
|
|
|
|
- checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
|
|
|
|
- return &Table{
|
|
|
|
- Name: stringx.From(tableName),
|
|
|
|
- PrimaryKey: primaryKey,
|
|
|
|
- UniqueIndex: uniqueIndex,
|
|
|
|
- NormalIndex: normalIndex,
|
|
|
|
- Fields: fields,
|
|
|
|
- }, nil
|
|
|
|
|
|
+ return list, nil
|
|
}
|
|
}
|
|
|
|
|
|
-func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
|
|
|
|
|
|
+func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string) {
|
|
log := console.NewColorConsole()
|
|
log := console.NewColorConsole()
|
|
uniqueSet := collection.NewSet()
|
|
uniqueSet := collection.NewSet()
|
|
for k, i := range uniqueIndex {
|
|
for k, i := range uniqueIndex {
|
|
@@ -131,26 +172,9 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
|
|
|
|
|
|
uniqueSet.AddStr(joinRet)
|
|
uniqueSet.AddStr(joinRet)
|
|
}
|
|
}
|
|
-
|
|
|
|
- normalIndexSet := collection.NewSet()
|
|
|
|
- for k, i := range normalIndex {
|
|
|
|
- var list []string
|
|
|
|
- for _, e := range i {
|
|
|
|
- list = append(list, e.Name.Source())
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- joinRet := strings.Join(list, ",")
|
|
|
|
- if normalIndexSet.Contains(joinRet) {
|
|
|
|
- log.Warning("table %s: duplicate index %s", tableName, joinRet)
|
|
|
|
- delete(normalIndex, k)
|
|
|
|
- continue
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- normalIndexSet.Add(joinRet)
|
|
|
|
- }
|
|
|
|
}
|
|
}
|
|
|
|
|
|
-func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
|
|
|
|
|
|
+func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
|
|
var (
|
|
var (
|
|
primaryKey Primary
|
|
primaryKey Primary
|
|
fieldM = make(map[string]*Field)
|
|
fieldM = make(map[string]*Field)
|
|
@@ -161,35 +185,35 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string)
|
|
continue
|
|
continue
|
|
}
|
|
}
|
|
|
|
|
|
- var comment string
|
|
|
|
- if column.Type.Comment != nil {
|
|
|
|
- comment = string(column.Type.Comment.Val)
|
|
|
|
- }
|
|
|
|
|
|
+ var (
|
|
|
|
+ comment string
|
|
|
|
+ isDefaultNull bool
|
|
|
|
+ )
|
|
|
|
|
|
- isDefaultNull := true
|
|
|
|
- if column.Type.NotNull {
|
|
|
|
- isDefaultNull = false
|
|
|
|
- } else {
|
|
|
|
- if column.Type.Default != nil {
|
|
|
|
|
|
+ if column.Constraint != nil {
|
|
|
|
+ comment = column.Constraint.Comment
|
|
|
|
+ isDefaultNull = !column.Constraint.HasDefaultValue
|
|
|
|
+ if column.Name == primaryColumn && column.Constraint.AutoIncrement {
|
|
isDefaultNull = false
|
|
isDefaultNull = false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
|
|
|
|
|
|
+ dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull)
|
|
if err != nil {
|
|
if err != nil {
|
|
return Primary{}, nil, err
|
|
return Primary{}, nil, err
|
|
}
|
|
}
|
|
|
|
|
|
var field Field
|
|
var field Field
|
|
- field.Name = stringx.From(column.Name.String())
|
|
|
|
- field.DataBaseType = column.Type.Type
|
|
|
|
|
|
+ field.Name = stringx.From(column.Name)
|
|
field.DataType = dataType
|
|
field.DataType = dataType
|
|
field.Comment = util.TrimNewLine(comment)
|
|
field.Comment = util.TrimNewLine(comment)
|
|
|
|
|
|
if field.Name.Source() == primaryColumn {
|
|
if field.Name.Source() == primaryColumn {
|
|
primaryKey = Primary{
|
|
primaryKey = Primary{
|
|
- Field: field,
|
|
|
|
- AutoIncrement: bool(column.Type.Autoincrement),
|
|
|
|
|
|
+ Field: field,
|
|
|
|
+ }
|
|
|
|
+ if column.Constraint != nil {
|
|
|
|
+ primaryKey.AutoIncrement = column.Constraint.AutoIncrement
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -198,60 +222,6 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string)
|
|
return primaryKey, fieldM, nil
|
|
return primaryKey, fieldM, nil
|
|
}
|
|
}
|
|
|
|
|
|
-func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) {
|
|
|
|
- var primaryColumn string
|
|
|
|
- uniqueKeyMap := make(map[string][]string)
|
|
|
|
- normalKeyMap := make(map[string][]string)
|
|
|
|
-
|
|
|
|
- isCreateTimeOrUpdateTime := func(name string) bool {
|
|
|
|
- camelColumnName := stringx.From(name).ToCamel()
|
|
|
|
- // by default, createTime|updateTime findOne is not used.
|
|
|
|
- return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime"
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- for _, index := range indexes {
|
|
|
|
- info := index.Info
|
|
|
|
- if info == nil {
|
|
|
|
- continue
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- indexName := index.Info.Name.String()
|
|
|
|
- if info.Primary {
|
|
|
|
- if len(index.Columns) > 1 {
|
|
|
|
- return "", nil, nil, errPrimaryKey
|
|
|
|
- }
|
|
|
|
- columnName := index.Columns[0].Column.String()
|
|
|
|
- if isCreateTimeOrUpdateTime(columnName) {
|
|
|
|
- continue
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- primaryColumn = columnName
|
|
|
|
- continue
|
|
|
|
- } else if info.Unique {
|
|
|
|
- for _, each := range index.Columns {
|
|
|
|
- columnName := each.Column.String()
|
|
|
|
- if isCreateTimeOrUpdateTime(columnName) {
|
|
|
|
- break
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName)
|
|
|
|
- }
|
|
|
|
- } else if info.Spatial {
|
|
|
|
- // do nothing
|
|
|
|
- } else {
|
|
|
|
- for _, each := range index.Columns {
|
|
|
|
- columnName := each.Column.String()
|
|
|
|
- if isCreateTimeOrUpdateTime(columnName) {
|
|
|
|
- break
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String())
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- return primaryColumn, uniqueKeyMap, normalKeyMap, nil
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
// ContainsTime returns true if contains golang type time.Time
|
|
// ContainsTime returns true if contains golang type time.Time
|
|
func (t *Table) ContainsTime() bool {
|
|
func (t *Table) ContainsTime() bool {
|
|
for _, item := range t.Fields {
|
|
for _, item := range t.Fields {
|
|
@@ -265,14 +235,13 @@ func (t *Table) ContainsTime() bool {
|
|
// ConvertDataType converts mysql data type into golang data type
|
|
// ConvertDataType converts mysql data type into golang data type
|
|
func ConvertDataType(table *model.Table) (*Table, error) {
|
|
func ConvertDataType(table *model.Table) (*Table, error) {
|
|
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
|
|
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
|
|
- primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
|
|
|
|
|
|
+ primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
|
|
if err != nil {
|
|
if err != nil {
|
|
return nil, err
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
|
|
var reply Table
|
|
var reply Table
|
|
reply.UniqueIndex = map[string][]*Field{}
|
|
reply.UniqueIndex = map[string][]*Field{}
|
|
- reply.NormalIndex = map[string][]*Field{}
|
|
|
|
reply.Name = stringx.From(table.Table)
|
|
reply.Name = stringx.From(table.Table)
|
|
seqInIndex := 0
|
|
seqInIndex := 0
|
|
if table.PrimaryKey.Index != nil {
|
|
if table.PrimaryKey.Index != nil {
|
|
@@ -282,7 +251,6 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
|
reply.PrimaryKey = Primary{
|
|
reply.PrimaryKey = Primary{
|
|
Field: Field{
|
|
Field: Field{
|
|
Name: stringx.From(table.PrimaryKey.Name),
|
|
Name: stringx.From(table.PrimaryKey.Name),
|
|
- DataBaseType: table.PrimaryKey.DataType,
|
|
|
|
DataType: primaryDataType,
|
|
DataType: primaryDataType,
|
|
Comment: table.PrimaryKey.Comment,
|
|
Comment: table.PrimaryKey.Comment,
|
|
SeqInIndex: seqInIndex,
|
|
SeqInIndex: seqInIndex,
|
|
@@ -338,29 +306,6 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
|
reply.UniqueIndex[indexName] = list
|
|
reply.UniqueIndex[indexName] = list
|
|
}
|
|
}
|
|
|
|
|
|
- normalIndexSet := collection.NewSet()
|
|
|
|
- for indexName, each := range table.NormalIndex {
|
|
|
|
- var list []*Field
|
|
|
|
- var normalJoin []string
|
|
|
|
- for _, c := range each {
|
|
|
|
- list = append(list, fieldM[c.Name])
|
|
|
|
- normalJoin = append(normalJoin, c.Name)
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- normalKey := strings.Join(normalJoin, ",")
|
|
|
|
- if normalIndexSet.Contains(normalKey) {
|
|
|
|
- log.Warning("table %s: duplicate index, %s", table.Table, normalKey)
|
|
|
|
- continue
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- normalIndexSet.AddStr(normalKey)
|
|
|
|
- sort.Slice(list, func(i, j int) bool {
|
|
|
|
- return list[i].SeqInIndex < list[j].SeqInIndex
|
|
|
|
- })
|
|
|
|
-
|
|
|
|
- reply.NormalIndex[indexName] = list
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
return &reply, nil
|
|
return &reply, nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -368,7 +313,7 @@ func getTableFields(table *model.Table) (map[string]*Field, error) {
|
|
fieldM := make(map[string]*Field)
|
|
fieldM := make(map[string]*Field)
|
|
for _, each := range table.Columns {
|
|
for _, each := range table.Columns {
|
|
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
|
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
|
- dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
|
|
|
|
|
+ dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull)
|
|
if err != nil {
|
|
if err != nil {
|
|
return nil, err
|
|
return nil, err
|
|
}
|
|
}
|
|
@@ -379,7 +324,6 @@ func getTableFields(table *model.Table) (map[string]*Field, error) {
|
|
|
|
|
|
field := &Field{
|
|
field := &Field{
|
|
Name: stringx.From(each.Name),
|
|
Name: stringx.From(each.Name),
|
|
- DataBaseType: each.DataType,
|
|
|
|
DataType: dt,
|
|
DataType: dt,
|
|
Comment: each.Comment,
|
|
Comment: each.Comment,
|
|
SeqInIndex: columnSeqInIndex,
|
|
SeqInIndex: columnSeqInIndex,
|