1
0
Эх сурвалжийг харах

Fix issue #1127 (#1131)

* fix #1127

* fix #1127

* fixed unit test

* add go keyword converter

Co-authored-by: anqiansong <anqiansong@bytedance.com>
anqiansong 3 жил өмнө
parent
commit
44202acb18

+ 2 - 0
tools/goctl/model/sql/example/sql/user.sql

@@ -8,12 +8,14 @@ CREATE TABLE `user`
     `mobile`      varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',
     `gender`      char(5) COLLATE utf8mb4_general_ci      NOT NULL COMMENT '男|女|未公\r开',
     `nickname`    varchar(255) COLLATE utf8mb4_general_ci          DEFAULT '' COMMENT '用户昵称',
+    `type`    tinyint(1) COLLATE utf8mb4_general_ci          DEFAULT 0 COMMENT '用户类型',
     `create_time` timestamp NULL,
     `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
     PRIMARY KEY (`id`),
     UNIQUE KEY `name_index` (`name`),
     UNIQUE KEY `name_index2` (`name`),
     UNIQUE KEY `user_index` (`user`),
+    UNIQUE KEY `type_index` (`type`),
     UNIQUE KEY `mobile_index` (`mobile`)
 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;
 

+ 4 - 2
tools/goctl/model/sql/model/informationschemamodel.go

@@ -6,6 +6,7 @@ import (
 
 	"github.com/tal-tech/go-zero/core/stores/sqlx"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
+	su "github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
 const indexPri = "PRIMARY"
@@ -144,14 +145,15 @@ func (m *InformationSchemaModel) FindIndex(db, table, column string) ([]*DbIndex
 // Convert converts column data into Table
 func (c *ColumnData) Convert() (*Table, error) {
 	var table Table
-	table.Table = c.Table
-	table.Db = c.Db
+	table.Table = su.EscapeGolangKeyword(c.Table)
+	table.Db = su.EscapeGolangKeyword(c.Db)
 	table.Columns = c.Columns
 	table.UniqueIndex = map[string][]*Column{}
 	table.NormalIndex = map[string][]*Column{}
 
 	m := make(map[string][]*Column)
 	for _, each := range c.Columns {
+		each.Name = su.EscapeGolangKeyword(each.Name)
 		each.Comment = util.TrimNewLine(each.Comment)
 		if each.Index != nil {
 			m[each.Index.IndexName] = append(m[each.Index.IndexName], each)

+ 38 - 6
tools/goctl/model/sql/parser/parser.go

@@ -10,6 +10,7 @@ import (
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
+	su "github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util/console"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 	"github.com/zeromicro/ddl-parser/parser"
@@ -49,11 +50,12 @@ type (
 // Parse parses ddl into golang structure
 func Parse(filename, database string) ([]*Table, error) {
 	p := parser.NewParser()
-	tables, err := p.From(filename)
+	ts, err := p.From(filename)
 	if err != nil {
 		return nil, err
 	}
 
+	tables := GetSafeTables(ts)
 	indexNameGen := func(column ...string) string {
 		return strings.Join(column, "_")
 	}
@@ -167,7 +169,7 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
 
 		joinRet := strings.Join(list, ",")
 		if uniqueSet.Contains(joinRet) {
-			log.Warning("table %s: duplicate unique index %s", tableName, joinRet)
+			log.Warning("[checkDuplicateUniqueIndex]: table %s: duplicate unique index %s", tableName, joinRet)
 			delete(uniqueIndex, k)
 			continue
 		}
@@ -213,10 +215,10 @@ func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, ma
 		if column.Constraint != nil {
 			if column.Name == primaryColumn {
 				if !column.Constraint.AutoIncrement && dataType == "int64" {
-					log.Warning("%s: The primary key is recommended to add constraint `AUTO_INCREMENT`", column.Name)
+					log.Warning("[convertColumns]: The primary key %q is recommended to add constraint `AUTO_INCREMENT`", column.Name)
 				}
 			} else if column.Constraint.NotNull && !column.Constraint.HasDefaultValue {
-				log.Warning("%s: The column is recommended to add constraint `DEFAULT`", column.Name)
+				log.Warning("[convertColumns]: The column %q is recommended to add constraint `DEFAULT`", column.Name)
 			}
 		}
 
@@ -302,7 +304,7 @@ func ConvertDataType(table *model.Table) (*Table, error) {
 		if len(each) == 1 {
 			one := each[0]
 			if one.Name == table.PrimaryKey.Name {
-				log.Warning("table %s: duplicate unique index with primary key, %s", table.Table, one.Name)
+				log.Warning("[ConvertDataType]: table q%, duplicate unique index with primary key:  %q", table.Table, one.Name)
 				continue
 			}
 		}
@@ -316,7 +318,7 @@ func ConvertDataType(table *model.Table) (*Table, error) {
 
 		uniqueKey := strings.Join(uniqueJoin, ",")
 		if uniqueIndexSet.Contains(uniqueKey) {
-			log.Warning("table %s: duplicate unique index, %s", table.Table, uniqueKey)
+			log.Warning("[ConvertDataType]: table %q, duplicate unique index %q", table.Table, uniqueKey)
 			continue
 		}
 
@@ -351,3 +353,33 @@ func getTableFields(table *model.Table) (map[string]*Field, error) {
 	}
 	return fieldM, nil
 }
+
+func GetSafeTables(tables []*parser.Table) []*parser.Table {
+	var list []*parser.Table
+	for _, t := range tables {
+		table := GetSafeTable(t)
+		list = append(list, table)
+	}
+
+	return list
+}
+
+func GetSafeTable(table *parser.Table) *parser.Table {
+	table.Name = su.EscapeGolangKeyword(table.Name)
+	for _, c := range table.Columns {
+		c.Name = su.EscapeGolangKeyword(c.Name)
+	}
+
+	for _, e := range table.Constraints {
+		var uniqueKeys, primaryKeys []string
+		for _, u := range e.ColumnUniqueKey {
+			uniqueKeys = append(uniqueKeys, su.EscapeGolangKeyword(u))
+		}
+		for _, p := range e.ColumnPrimaryKey {
+			primaryKeys = append(primaryKeys, su.EscapeGolangKeyword(p))
+		}
+		e.ColumnUniqueKey = uniqueKeys
+		e.ColumnPrimaryKey = primaryKeys
+	}
+	return table
+}

+ 47 - 1
tools/goctl/util/string.go

@@ -1,6 +1,37 @@
 package util
 
-import "strings"
+import (
+	"strings"
+
+	"github.com/tal-tech/go-zero/tools/goctl/util/console"
+)
+
+var goKeyword = map[string]string{
+	"var":         "variable",
+	"const":       "constant",
+	"package":     "pkg",
+	"func":        "function",
+	"return":      "rtn",
+	"defer":       "dfr",
+	"go":          "goo",
+	"select":      "slt",
+	"struct":      "structure",
+	"interface":   "itf",
+	"chan":        "channel",
+	"type":        "tp",
+	"map":         "mp",
+	"range":       "rg",
+	"break":       "brk",
+	"case":        "caz",
+	"continue":    "ctn",
+	"for":         "fr",
+	"fallthrough": "fth",
+	"else":        "es",
+	"if":          "ef",
+	"switch":      "swt",
+	"goto":        "gt",
+	"default":     "dft",
+}
 
 // Title returns a string value with s[0] which has been convert into upper case that
 // there are not empty input text
@@ -64,3 +95,18 @@ func isLetter(r rune) bool {
 func isNumber(r rune) bool {
 	return '0' <= r && r <= '9'
 }
+
+func EscapeGolangKeyword(s string) string {
+	if !isGolangKeyword(s) {
+		return s
+	}
+
+	r := goKeyword[s]
+	console.Info("[EscapeGolangKeyword]: go keyword is forbidden %q, converted into %q", s, r)
+	return r
+}
+
+func isGolangKeyword(s string) bool {
+	_, ok := goKeyword[s]
+	return ok
+}

+ 8 - 0
tools/goctl/util/string_test.go

@@ -1,6 +1,7 @@
 package util
 
 import (
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -64,3 +65,10 @@ func TestSafeString(t *testing.T) {
 		assert.Equal(t, e.expected, SafeString(e.input))
 	}
 }
+
+func TestEscapeGoKeyword(t *testing.T) {
+	for k := range goKeyword {
+		assert.Equal(t, goKeyword[k], EscapeGolangKeyword(k))
+		assert.False(t, isGolangKeyword(strings.Title(k)))
+	}
+}