Quellcode durchsuchen

Fix issues: #725, #740 (#813)

* Fix issues: #725, #740

* Update filed sort

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
anqiansong vor 3 Jahren
Ursprung
Commit
9b2a279948

+ 1 - 1
go.mod

@@ -35,8 +35,8 @@ require (
 	github.com/spaolacci/murmur3 v1.1.0
 	github.com/stretchr/testify v1.7.0
 	github.com/urfave/cli v1.22.5
-	github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
 	github.com/zeromicro/antlr v0.0.1
+	github.com/zeromicro/ddl-parser v0.0.0-20210712021150-63520aca7348 // indirect
 	go.etcd.io/etcd/api/v3 v3.5.0
 	go.etcd.io/etcd/client/v3 v3.5.0
 	go.uber.org/automaxprocs v1.3.0

+ 6 - 2
go.sum

@@ -15,6 +15,8 @@ github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGn
 github.com/alicebob/miniredis/v2 v2.14.1 h1:GjlbSeoJ24bzdLRs13HoMEeaRZx9kg5nHoRW7QV/nCs=
 github.com/alicebob/miniredis/v2 v2.14.1/go.mod h1:uS970Sw5Gs9/iK3yBg0l9Uj9s25wXxSpQUE9EaJ/Blg=
 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
+github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A=
+github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY=
 github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -225,8 +227,6 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/urfave/cli v1.22.5 h1:lNq9sAHXK2qfdI8W+GRItjCEkI+2oR4d+MEHy1CKXoU=
 github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
-github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ=
-github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -234,6 +234,10 @@ github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox
 github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ=
 github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk=
 github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
+github.com/zeromicro/ddl-parser v0.0.0-20210710132903-bc9dbb9789b1 h1:zItUIfobEHTYD9X0fAt9QWEWIFWDa8CypF+Z62zIR+M=
+github.com/zeromicro/ddl-parser v0.0.0-20210710132903-bc9dbb9789b1/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
+github.com/zeromicro/ddl-parser v0.0.0-20210712021150-63520aca7348 h1:OhxL9tn28gDeJVzreIUiE5oVxZCjL3tBJ0XBNw8p5R8=
+github.com/zeromicro/ddl-parser v0.0.0-20210712021150-63520aca7348/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
 go.etcd.io/etcd/api/v3 v3.5.0 h1:GsV3S+OfZEOCNXdtNkBSR7kgLobAa/SO6tCxRa0GAYw=
 go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs=
 go.etcd.io/etcd/client/pkg/v3 v3.5.0 h1:2aQv6F436YnN7I4VbI8PPYrBhu+SmrTaADcf8Mi/6PU=

+ 7 - 11
tools/goctl/model/sql/command/command.go

@@ -2,7 +2,6 @@ package command
 
 import (
 	"errors"
-	"io/ioutil"
 	"path/filepath"
 	"strings"
 
@@ -76,22 +75,19 @@ func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error {
 		return errNotMatched
 	}
 
-	var source []string
+	generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
+	if err != nil {
+		return err
+	}
+
 	for _, file := range files {
-		data, err := ioutil.ReadFile(file)
+		err = generator.StartFromDDL(file, cache)
 		if err != nil {
 			return err
 		}
-
-		source = append(source, string(data))
-	}
-
-	generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
-	if err != nil {
-		return err
 	}
 
-	return generator.StartFromDDL(strings.Join(source, "\n"), cache)
+	return nil
 }
 
 func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error {

+ 58 - 4
tools/goctl/model/sql/converter/types.go

@@ -3,9 +3,53 @@ package converter
 import (
 	"fmt"
 	"strings"
+
+	"github.com/zeromicro/ddl-parser/parser"
 )
 
-var commonMysqlDataTypeMap = map[string]string{
+var commonMysqlDataTypeMap = map[int]string{
+	// For consistency, all integer types are converted to int64
+	// number
+	parser.Bool:      "int64",
+	parser.Boolean:   "int64",
+	parser.TinyInt:   "int64",
+	parser.SmallInt:  "int64",
+	parser.MediumInt: "int64",
+	parser.Int:       "int64",
+	parser.MiddleInt: "int64",
+	parser.Int1:      "int64",
+	parser.Int2:      "int64",
+	parser.Int3:      "int64",
+	parser.Int4:      "int64",
+	parser.Int8:      "int64",
+	parser.Integer:   "int64",
+	parser.BigInt:    "int64",
+	parser.Float:     "float64",
+	parser.Float4:    "float64",
+	parser.Float8:    "float64",
+	parser.Double:    "float64",
+	parser.Decimal:   "float64",
+	// date&time
+	parser.Date:      "time.Time",
+	parser.DateTime:  "time.Time",
+	parser.Timestamp: "time.Time",
+	parser.Time:      "string",
+	parser.Year:      "int64",
+	// string
+	parser.Char:       "string",
+	parser.VarChar:    "string",
+	parser.Binary:     "string",
+	parser.VarBinary:  "string",
+	parser.TinyText:   "string",
+	parser.Text:       "string",
+	parser.MediumText: "string",
+	parser.LongText:   "string",
+	parser.Enum:       "string",
+	parser.Set:        "string",
+	parser.Json:       "string",
+}
+
+var commonMysqlDataTypeMap2 = map[string]string{
 	// For consistency, all integer types are converted to int64
 	// number
 	"bool":      "int64",
@@ -40,10 +84,20 @@ var commonMysqlDataTypeMap = map[string]string{
 }
 
 // ConvertDataType converts mysql column type into golang type
-func ConvertDataType(dataBaseType string, isDefaultNull bool) (string, error) {
-	tp, ok := commonMysqlDataTypeMap[strings.ToLower(dataBaseType)]
+func ConvertDataType(dataBaseType int, isDefaultNull bool) (string, error) {
+	tp, ok := commonMysqlDataTypeMap[dataBaseType]
+	if !ok {
+		return "", fmt.Errorf("unsupported database type: %v", dataBaseType)
+	}
+
+	return mayConvertNullType(tp, isDefaultNull), nil
+}
+
+// ConvertStringDataType converts mysql column type into golang type
+func ConvertStringDataType(dataBaseType string, isDefaultNull bool) (string, error) {
+	tp, ok := commonMysqlDataTypeMap2[strings.ToLower(dataBaseType)]
 	if !ok {
-		return "", fmt.Errorf("unexpected database type: %s", dataBaseType)
+		return "", fmt.Errorf("unsupported database type: %s", dataBaseType)
 	}
 
 	return mayConvertNullType(tp, isDefaultNull), nil

+ 5 - 7
tools/goctl/model/sql/converter/types_test.go

@@ -4,25 +4,23 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/zeromicro/ddl-parser/parser"
 )
 
 func TestConvertDataType(t *testing.T) {
-	v, err := ConvertDataType("tinyint", false)
+	v, err := ConvertDataType(parser.TinyInt, false)
 	assert.Nil(t, err)
 	assert.Equal(t, "int64", v)
 
-	v, err = ConvertDataType("tinyint", true)
+	v, err = ConvertDataType(parser.TinyInt, true)
 	assert.Nil(t, err)
 	assert.Equal(t, "sql.NullInt64", v)
 
-	v, err = ConvertDataType("timestamp", false)
+	v, err = ConvertDataType(parser.Timestamp, false)
 	assert.Nil(t, err)
 	assert.Equal(t, "time.Time", v)
 
-	v, err = ConvertDataType("timestamp", true)
+	v, err = ConvertDataType(parser.Timestamp, true)
 	assert.Nil(t, err)
 	assert.Equal(t, "sql.NullTime", v)
-
-	_, err = ConvertDataType("float32", false)
-	assert.NotNil(t, err)
 }

+ 10 - 11
tools/goctl/model/sql/gen/gen.go

@@ -90,8 +90,8 @@ func newDefaultOption() Option {
 	}
 }
 
-func (g *defaultGenerator) StartFromDDL(source string, withCache bool) error {
-	modelList, err := g.genFromDDL(source, withCache)
+func (g *defaultGenerator) StartFromDDL(filename string, withCache bool) error {
+	modelList, err := g.genFromDDL(filename, withCache)
 	if err != nil {
 		return err
 	}
@@ -174,21 +174,20 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
 }
 
 // ret1: key-table name,value-code
-func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string]string, error) {
-	ddlList := g.split(source)
+func (g *defaultGenerator) genFromDDL(filename string, withCache bool) (map[string]string, error) {
 	m := make(map[string]string)
-	for _, ddl := range ddlList {
-		table, err := parser.Parse(ddl)
-		if err != nil {
-			return nil, err
-		}
+	tables, err := parser.Parse(filename)
+	if err != nil {
+		return nil, err
+	}
 
-		code, err := g.genModel(*table, withCache)
+	for _, e := range tables {
+		code, err := g.genModel(*e, withCache)
 		if err != nil {
 			return nil, err
 		}
 
-		m[table.Name.Source()] = code
+		m[e.Name.Source()] = code
 	}
 
 	return m, nil

+ 15 - 4
tools/goctl/model/sql/gen/gen_test.go

@@ -2,6 +2,7 @@ package gen
 
 import (
 	"database/sql"
+	"io/ioutil"
 	"os"
 	"path/filepath"
 	"strings"
@@ -20,6 +21,11 @@ var source = "CREATE TABLE `test_user` (\n  `id` bigint NOT NULL AUTO_INCREMENT,
 func TestCacheModel(t *testing.T) {
 	logx.Disable()
 	_ = Clean()
+
+	sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
+	err := ioutil.WriteFile(sqlFile, []byte(source), 0777)
+	assert.Nil(t, err)
+
 	dir := filepath.Join(t.TempDir(), "./testmodel")
 	cacheDir := filepath.Join(dir, "cache")
 	noCacheDir := filepath.Join(dir, "nocache")
@@ -28,7 +34,7 @@ func TestCacheModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(source, true)
+	err = g.StartFromDDL(sqlFile, true)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
@@ -39,7 +45,7 @@ func TestCacheModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(source, false)
+	err = g.StartFromDDL(sqlFile, false)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
@@ -50,6 +56,11 @@ func TestCacheModel(t *testing.T) {
 func TestNamingModel(t *testing.T) {
 	logx.Disable()
 	_ = Clean()
+
+	sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
+	err := ioutil.WriteFile(sqlFile, []byte(source), 0777)
+	assert.Nil(t, err)
+
 	dir, _ := filepath.Abs("./testmodel")
 	camelDir := filepath.Join(dir, "camel")
 	snakeDir := filepath.Join(dir, "snake")
@@ -61,7 +72,7 @@ func TestNamingModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(source, true)
+	err = g.StartFromDDL(sqlFile, true)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
@@ -72,7 +83,7 @@ func TestNamingModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(source, true)
+	err = g.StartFromDDL(sqlFile, true)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))

+ 22 - 29
tools/goctl/model/sql/gen/keys_test.go

@@ -11,32 +11,28 @@ import (
 
 func TestGenCacheKeys(t *testing.T) {
 	primaryField := &parser.Field{
-		Name:         stringx.From("id"),
-		DataBaseType: "bigint",
-		DataType:     "int64",
-		Comment:      "自增id",
-		SeqInIndex:   1,
+		Name:       stringx.From("id"),
+		DataType:   "int64",
+		Comment:    "自增id",
+		SeqInIndex: 1,
 	}
 	mobileField := &parser.Field{
-		Name:         stringx.From("mobile"),
-		DataBaseType: "varchar",
-		DataType:     "string",
-		Comment:      "手机号",
-		SeqInIndex:   1,
+		Name:       stringx.From("mobile"),
+		DataType:   "string",
+		Comment:    "手机号",
+		SeqInIndex: 1,
 	}
 	classField := &parser.Field{
-		Name:         stringx.From("class"),
-		DataBaseType: "varchar",
-		DataType:     "string",
-		Comment:      "班级",
-		SeqInIndex:   1,
+		Name:       stringx.From("class"),
+		DataType:   "string",
+		Comment:    "班级",
+		SeqInIndex: 1,
 	}
 	nameField := &parser.Field{
-		Name:         stringx.From("name"),
-		DataBaseType: "varchar",
-		DataType:     "string",
-		Comment:      "姓名",
-		SeqInIndex:   2,
+		Name:       stringx.From("name"),
+		DataType:   "string",
+		Comment:    "姓名",
+		SeqInIndex: 2,
 	}
 	primariCacheKey, uniqueCacheKey := genCacheKeys(parser.Table{
 		Name: stringx.From("user"),
@@ -53,23 +49,20 @@ func TestGenCacheKeys(t *testing.T) {
 				nameField,
 			},
 		},
-		NormalIndex: nil,
 		Fields: []*parser.Field{
 			primaryField,
 			mobileField,
 			classField,
 			nameField,
 			{
-				Name:         stringx.From("createTime"),
-				DataBaseType: "timestamp",
-				DataType:     "time.Time",
-				Comment:      "创建时间",
+				Name:     stringx.From("createTime"),
+				DataType: "time.Time",
+				Comment:  "创建时间",
 			},
 			{
-				Name:         stringx.From("updateTime"),
-				DataBaseType: "timestamp",
-				DataType:     "time.Time",
-				Comment:      "更新时间",
+				Name:     stringx.From("updateTime"),
+				DataType: "time.Time",
+				Comment:  "更新时间",
 			},
 		},
 	})

+ 0 - 11
tools/goctl/model/sql/parser/error.go

@@ -1,11 +0,0 @@
-package parser
-
-import (
-	"errors"
-)
-
-var (
-	errUnsupportDDL      = errors.New("unexpected type")
-	errTableBodyNotFound = errors.New("create table spec not found")
-	errPrimaryKey        = errors.New("unexpected join primary key")
-)

+ 109 - 165
tools/goctl/model/sql/parser/parser.go

@@ -2,6 +2,7 @@ package parser
 
 import (
 	"fmt"
+	"path/filepath"
 	"sort"
 	"strings"
 
@@ -11,7 +12,7 @@ import (
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util/console"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
-	"github.com/xwb1989/sqlparser"
+	"github.com/zeromicro/ddl-parser/parser"
 )
 
 const timeImport = "time.Time"
@@ -22,7 +23,6 @@ type (
 		Name        stringx.String
 		PrimaryKey  Primary
 		UniqueIndex map[string][]*Field
-		NormalIndex map[string][]*Field
 		Fields      []*Field
 	}
 
@@ -35,7 +35,6 @@ type (
 	// Field describes a table field
 	Field struct {
 		Name            stringx.String
-		DataBaseType    string
 		DataType        string
 		Comment         string
 		SeqInIndex      int
@@ -47,73 +46,115 @@ type (
 )
 
 // Parse parses ddl into golang structure
-func Parse(ddl string) (*Table, error) {
-	stmt, err := sqlparser.ParseStrictDDL(ddl)
+func Parse(filename string) ([]*Table, error) {
+	p := parser.NewParser()
+	tables, err := p.From(filename)
 	if err != nil {
 		return nil, err
 	}
 
-	ddlStmt, ok := stmt.(*sqlparser.DDL)
-	if !ok {
-		return nil, errUnsupportDDL
+	indexNameGen := func(column ...string) string {
+		return strings.Join(column, "_")
 	}
 
-	action := ddlStmt.Action
-	if action != sqlparser.CreateStr {
-		return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action)
-	}
+	prefix := filepath.Base(filename)
+	var list []*Table
+	for _, e := range tables {
+		columns := e.Columns
 
-	tableName := ddlStmt.NewName.Name.String()
-	tableSpec := ddlStmt.TableSpec
-	if tableSpec == nil {
-		return nil, errTableBodyNotFound
-	}
+		var (
+			primaryColumnSet = collection.NewSet()
 
-	columns := tableSpec.Columns
-	indexes := tableSpec.Indexes
-	primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes)
-	if err != nil {
-		return nil, err
-	}
+			primaryColumn string
+			uniqueKeyMap  = make(map[string][]string)
+			normalKeyMap  = make(map[string][]string)
+		)
 
-	primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
-	if err != nil {
-		return nil, err
-	}
+		for _, column := range columns {
+			if column.Constraint != nil {
+				if column.Constraint.Primary {
+					primaryColumnSet.AddStr(column.Name)
+				}
 
-	var fields []*Field
-	for _, e := range fieldM {
-		fields = append(fields, e)
-	}
+				if column.Constraint.Unique {
+					indexName := indexNameGen(column.Name, "unique")
+					uniqueKeyMap[indexName] = []string{column.Name}
+				}
 
-	var (
-		uniqueIndex = make(map[string][]*Field)
-		normalIndex = make(map[string][]*Field)
-	)
+				if column.Constraint.Key {
+					indexName := indexNameGen(column.Name, "idx")
+					uniqueKeyMap[indexName] = []string{column.Name}
+				}
+			}
+		}
 
-	for indexName, each := range uniqueKeyMap {
-		for _, columnName := range each {
-			uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
+		for _, e := range e.Constraints {
+			if len(e.ColumnPrimaryKey) > 1 {
+				return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
+			}
+
+			if len(e.ColumnPrimaryKey) == 1 {
+				primaryColumn = e.ColumnPrimaryKey[0]
+				primaryColumnSet.AddStr(e.ColumnPrimaryKey[0])
+			}
+
+			if len(e.ColumnUniqueKey) > 0 {
+				list := append([]string(nil), e.ColumnUniqueKey...)
+				list = append(list, "unique")
+				indexName := indexNameGen(list...)
+				uniqueKeyMap[indexName] = e.ColumnUniqueKey
+			}
 		}
-	}
 
-	for indexName, each := range normalKeyMap {
-		for _, columnName := range each {
-			normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
+		if primaryColumnSet.Count() > 1 {
+			return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
 		}
+
+		primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
+		if err != nil {
+			return nil, err
+		}
+
+		var fields []*Field
+		// sort
+		for _, c := range columns {
+			field, ok := fieldM[c.Name]
+			if ok {
+				fields = append(fields, field)
+			}
+		}
+
+		var (
+			uniqueIndex = make(map[string][]*Field)
+			normalIndex = make(map[string][]*Field)
+		)
+
+		for indexName, each := range uniqueKeyMap {
+			for _, columnName := range each {
+				uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
+			}
+		}
+
+		for indexName, each := range normalKeyMap {
+			for _, columnName := range each {
+				normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
+			}
+		}
+
+		checkDuplicateUniqueIndex(uniqueIndex, e.Name)
+
+		list = append(list, &Table{
+			Name:        stringx.From(e.Name),
+			PrimaryKey:  primaryKey,
+			UniqueIndex: uniqueIndex,
+			Fields:      fields,
+		})
 	}
 
-	checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
-	return &Table{
-		Name:        stringx.From(tableName),
-		PrimaryKey:  primaryKey,
-		UniqueIndex: uniqueIndex,
-		NormalIndex: normalIndex,
-		Fields:      fields,
-	}, nil
+	return list, nil
 }
 
-func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
+func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string) {
 	log := console.NewColorConsole()
 	uniqueSet := collection.NewSet()
 	for k, i := range uniqueIndex {
@@ -131,26 +172,9 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
 
 		uniqueSet.AddStr(joinRet)
 	}
-
-	normalIndexSet := collection.NewSet()
-	for k, i := range normalIndex {
-		var list []string
-		for _, e := range i {
-			list = append(list, e.Name.Source())
-		}
-
-		joinRet := strings.Join(list, ",")
-		if normalIndexSet.Contains(joinRet) {
-			log.Warning("table %s: duplicate index %s", tableName, joinRet)
-			delete(normalIndex, k)
-			continue
-		}
-
-		normalIndexSet.Add(joinRet)
-	}
 }
 
-func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
+func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
 	var (
 		primaryKey Primary
 		fieldM     = make(map[string]*Field)
@@ -161,35 +185,35 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string)
 			continue
 		}
 
-		var comment string
-		if column.Type.Comment != nil {
-			comment = string(column.Type.Comment.Val)
-		}
+		var (
+			comment       string
+			isDefaultNull bool
+		)
 
-		isDefaultNull := true
-		if column.Type.NotNull {
-			isDefaultNull = false
-		} else {
-			if column.Type.Default != nil {
+		if column.Constraint != nil {
+			comment = column.Constraint.Comment
+			isDefaultNull = !column.Constraint.HasDefaultValue
+			if column.Name == primaryColumn && column.Constraint.AutoIncrement {
 				isDefaultNull = false
 			}
 		}
 
-		dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
+		dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull)
 		if err != nil {
 			return Primary{}, nil, err
 		}
 
 		var field Field
-		field.Name = stringx.From(column.Name.String())
-		field.DataBaseType = column.Type.Type
+		field.Name = stringx.From(column.Name)
 		field.DataType = dataType
 		field.Comment = util.TrimNewLine(comment)
 
 		if field.Name.Source() == primaryColumn {
 			primaryKey = Primary{
-				Field:         field,
-				AutoIncrement: bool(column.Type.Autoincrement),
+				Field: field,
+			}
+			if column.Constraint != nil {
+				primaryKey.AutoIncrement = column.Constraint.AutoIncrement
 			}
 		}
 
@@ -198,60 +222,6 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string)
 	return primaryKey, fieldM, nil
 }
 
-func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) {
-	var primaryColumn string
-	uniqueKeyMap := make(map[string][]string)
-	normalKeyMap := make(map[string][]string)
-
-	isCreateTimeOrUpdateTime := func(name string) bool {
-		camelColumnName := stringx.From(name).ToCamel()
-		// by default, createTime|updateTime findOne is not used.
-		return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime"
-	}
-
-	for _, index := range indexes {
-		info := index.Info
-		if info == nil {
-			continue
-		}
-
-		indexName := index.Info.Name.String()
-		if info.Primary {
-			if len(index.Columns) > 1 {
-				return "", nil, nil, errPrimaryKey
-			}
-			columnName := index.Columns[0].Column.String()
-			if isCreateTimeOrUpdateTime(columnName) {
-				continue
-			}
-
-			primaryColumn = columnName
-			continue
-		} else if info.Unique {
-			for _, each := range index.Columns {
-				columnName := each.Column.String()
-				if isCreateTimeOrUpdateTime(columnName) {
-					break
-				}
-
-				uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName)
-			}
-		} else if info.Spatial {
-			// do nothing
-		} else {
-			for _, each := range index.Columns {
-				columnName := each.Column.String()
-				if isCreateTimeOrUpdateTime(columnName) {
-					break
-				}
-
-				normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String())
-			}
-		}
-	}
-	return primaryColumn, uniqueKeyMap, normalKeyMap, nil
-}
-
 // ContainsTime returns true if contains golang type time.Time
 func (t *Table) ContainsTime() bool {
 	for _, item := range t.Fields {
@@ -265,14 +235,13 @@ func (t *Table) ContainsTime() bool {
 // ConvertDataType converts mysql data type into golang data type
 func ConvertDataType(table *model.Table) (*Table, error) {
 	isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
-	primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
+	primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
 	if err != nil {
 		return nil, err
 	}
 
 	var reply Table
 	reply.UniqueIndex = map[string][]*Field{}
-	reply.NormalIndex = map[string][]*Field{}
 	reply.Name = stringx.From(table.Table)
 	seqInIndex := 0
 	if table.PrimaryKey.Index != nil {
@@ -282,7 +251,6 @@ func ConvertDataType(table *model.Table) (*Table, error) {
 	reply.PrimaryKey = Primary{
 		Field: Field{
 			Name:            stringx.From(table.PrimaryKey.Name),
-			DataBaseType:    table.PrimaryKey.DataType,
 			DataType:        primaryDataType,
 			Comment:         table.PrimaryKey.Comment,
 			SeqInIndex:      seqInIndex,
@@ -338,29 +306,6 @@ func ConvertDataType(table *model.Table) (*Table, error) {
 		reply.UniqueIndex[indexName] = list
 	}
 
-	normalIndexSet := collection.NewSet()
-	for indexName, each := range table.NormalIndex {
-		var list []*Field
-		var normalJoin []string
-		for _, c := range each {
-			list = append(list, fieldM[c.Name])
-			normalJoin = append(normalJoin, c.Name)
-		}
-
-		normalKey := strings.Join(normalJoin, ",")
-		if normalIndexSet.Contains(normalKey) {
-			log.Warning("table %s: duplicate index, %s", table.Table, normalKey)
-			continue
-		}
-
-		normalIndexSet.AddStr(normalKey)
-		sort.Slice(list, func(i, j int) bool {
-			return list[i].SeqInIndex < list[j].SeqInIndex
-		})
-
-		reply.NormalIndex[indexName] = list
-	}
-
 	return &reply, nil
 }
 
@@ -368,7 +313,7 @@ func getTableFields(table *model.Table) (map[string]*Field, error) {
 	fieldM := make(map[string]*Field)
 	for _, each := range table.Columns {
 		isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
-		dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
+		dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull)
 		if err != nil {
 			return nil, err
 		}
@@ -379,7 +324,6 @@ func getTableFields(table *model.Table) (map[string]*Field, error) {
 
 		field := &Field{
 			Name:            stringx.From(each.Name),
-			DataBaseType:    each.DataType,
 			DataType:        dt,
 			Comment:         each.Comment,
 			SeqInIndex:      columnSeqInIndex,

+ 22 - 63
tools/goctl/model/sql/parser/parser_test.go

@@ -1,88 +1,47 @@
 package parser
 
 import (
-	"sort"
+	"io/ioutil"
+	"path/filepath"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
 func TestParsePlainText(t *testing.T) {
-	_, err := Parse("plain text")
+	sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
+	err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0777)
+	assert.Nil(t, err)
+
+	_, err = Parse(sqlFile)
 	assert.NotNil(t, err)
 }
 
 func TestParseSelect(t *testing.T) {
-	_, err := Parse("select * from user")
-	assert.Equal(t, errUnsupportDDL, err)
+	sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
+	err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0777)
+	assert.Nil(t, err)
+
+	tables, err := Parse(sqlFile)
+	assert.Nil(t, err)
+	assert.Equal(t, 0, len(tables))
 }
 
 func TestParseCreateTable(t *testing.T) {
-	table, err := Parse("CREATE TABLE `test_user` (\n  `id` bigint NOT NULL AUTO_INCREMENT,\n  `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机  号',\n  `class` bigint NOT NULL comment '班级',\n  `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n  名',\n  `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n  `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n  PRIMARY KEY (`id`),\n  UNIQUE KEY `mobile_unique` (`mobile`),\n  UNIQUE KEY `class_name_unique` (`class`,`name`),\n  KEY `create_index` (`create_time`),\n  KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;")
+	sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
+	err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n  `id` bigint NOT NULL AUTO_INCREMENT,\n  `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机  号',\n  `class` bigint NOT NULL comment '班级',\n  `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n  名',\n  `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n  `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n  PRIMARY KEY (`id`),\n  UNIQUE KEY `mobile_unique` (`mobile`),\n  UNIQUE KEY `class_name_unique` (`class`,`name`),\n  KEY `create_index` (`create_time`),\n  KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0777)
+	assert.Nil(t, err)
+
+	tables, err := Parse(sqlFile)
+	assert.Equal(t, 1, len(tables))
+	table := tables[0]
 	assert.Nil(t, err)
 	assert.Equal(t, "test_user", table.Name.Source())
 	assert.Equal(t, "id", table.PrimaryKey.Name.Source())
 	assert.Equal(t, true, table.ContainsTime())
-	assert.Equal(t, true, func() bool {
-		mobileUniqueIndex, ok := table.UniqueIndex["mobile_unique"]
-		if !ok {
-			return false
-		}
-
-		classNameUniqueIndex, ok := table.UniqueIndex["class_name_unique"]
-		if !ok {
-			return false
-		}
-
-		equal := func(f1, f2 []*Field) bool {
-			sort.Slice(f1, func(i, j int) bool {
-				return f1[i].Name.Source() < f1[j].Name.Source()
-			})
-			sort.Slice(f2, func(i, j int) bool {
-				return f2[i].Name.Source() < f2[j].Name.Source()
-			})
-
-			if len(f2) != len(f2) {
-				return false
-			}
-
-			for index, f := range f1 {
-				if f1[index].Name.Source() != f.Name.Source() {
-					return false
-				}
-			}
-			return true
-		}
-
-		if !equal(mobileUniqueIndex, []*Field{
-			{
-				Name:         stringx.From("mobile"),
-				DataBaseType: "varchar",
-				DataType:     "string",
-				SeqInIndex:   1,
-			},
-		}) {
-			return false
-		}
-
-		return equal(classNameUniqueIndex, []*Field{
-			{
-				Name:         stringx.From("class"),
-				DataBaseType: "bigint",
-				DataType:     "int64",
-				SeqInIndex:   1,
-			},
-			{
-				Name:         stringx.From("name"),
-				DataBaseType: "varchar",
-				DataType:     "string",
-				SeqInIndex:   2,
-			},
-		})
-	}())
+	assert.Equal(t, 2, len(table.UniqueIndex))
 	assert.True(t, func() bool {
 		for _, e := range table.Fields {
 			if e.Comment != util.TrimNewLine(e.Comment) {