浏览代码

Added database prefix of cache key. (#835)

fangjianwei 3 年之前
父节点
当前提交
476026e393

+ 4 - 0
tools/goctl/goctl.go

@@ -419,6 +419,10 @@ var (
 									Name:  "idea",
 									Usage: "for idea plugin [optional]",
 								},
+								cli.StringFlag{
+									Name:  "database, db",
+									Usage: "the name of database [optional]",
+								},
 							},
 							Action: model.MysqlDDL,
 						},

+ 1 - 0
tools/goctl/model/sql/README.MD

@@ -264,6 +264,7 @@ OPTIONS:
        --style value          the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]
        --cache, -c            generate code with cache [optional]
        --idea                 for idea plugin [optional]
+       --database, -db        the name of database [optional]
 	```
 
   * datasource

+ 5 - 3
tools/goctl/model/sql/command/command.go

@@ -24,6 +24,7 @@ const (
 	flagURL   = "url"
 	flagTable = "table"
 	flagStyle = "style"
+	flagDatabase = "database"
 )
 
 var errNotMatched = errors.New("sql not matched")
@@ -35,12 +36,13 @@ func MysqlDDL(ctx *cli.Context) error {
 	cache := ctx.Bool(flagCache)
 	idea := ctx.Bool(flagIdea)
 	style := ctx.String(flagStyle)
+	database := ctx.String(flagDatabase)
 	cfg, err := config.NewConfig(style)
 	if err != nil {
 		return err
 	}
 
-	return fromDDl(src, dir, cfg, cache, idea)
+	return fromDDl(src, dir, cfg, cache, idea, database)
 }
 
 // MyDataSource generates model code from datasource
@@ -59,7 +61,7 @@ func MyDataSource(ctx *cli.Context) error {
 	return fromDataSource(url, pattern, dir, cfg, cache, idea)
 }
 
-func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error {
+func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database string) error {
 	log := console.NewConsole(idea)
 	src = strings.TrimSpace(src)
 	if len(src) == 0 {
@@ -81,7 +83,7 @@ func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error {
 	}
 
 	for _, file := range files {
-		err = generator.StartFromDDL(file, cache)
+		err = generator.StartFromDDL(file, cache, database)
 		if err != nil {
 			return err
 		}

+ 4 - 4
tools/goctl/model/sql/command/command_test.go

@@ -24,12 +24,12 @@ func TestFromDDl(t *testing.T) {
 	err := gen.Clean()
 	assert.Nil(t, err)
 
-	err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
+	err = fromDDl("./user.sql", t.TempDir(), cfg, true, false, "go_zero")
 	assert.Equal(t, errNotMatched, err)
 
 	// case dir is not exists
 	unknownDir := filepath.Join(t.TempDir(), "test", "user.sql")
-	err = fromDDl(unknownDir, t.TempDir(), cfg, true, false)
+	err = fromDDl(unknownDir, t.TempDir(), cfg, true, false, "go_zero")
 	assert.True(t, func() bool {
 		switch err.(type) {
 		case *os.PathError:
@@ -40,7 +40,7 @@ func TestFromDDl(t *testing.T) {
 	}())
 
 	// case empty src
-	err = fromDDl("", t.TempDir(), cfg, true, false)
+	err = fromDDl("", t.TempDir(), cfg, true, false, "go_zero")
 	if err != nil {
 		assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
 	}
@@ -70,7 +70,7 @@ func TestFromDDl(t *testing.T) {
 	_, err = os.Stat(user2Sql)
 	assert.Nil(t, err)
 
-	err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false)
+	err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, "go_zero")
 	assert.Nil(t, err)
 
 	_, err = os.Stat(filepath.Join(tempDir, "usermodel.go"))

+ 4 - 4
tools/goctl/model/sql/gen/gen.go

@@ -90,8 +90,8 @@ func newDefaultOption() Option {
 	}
 }
 
-func (g *defaultGenerator) StartFromDDL(filename string, withCache bool) error {
-	modelList, err := g.genFromDDL(filename, withCache)
+func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, database string) error {
+	modelList, err := g.genFromDDL(filename, withCache, database)
 	if err != nil {
 		return err
 	}
@@ -174,9 +174,9 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
 }
 
 // ret1: key-table name,value-code
-func (g *defaultGenerator) genFromDDL(filename string, withCache bool) (map[string]string, error) {
+func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]string, error) {
 	m := make(map[string]string)
-	tables, err := parser.Parse(filename)
+	tables, err := parser.Parse(filename, database)
 	if err != nil {
 		return nil, err
 	}

+ 4 - 4
tools/goctl/model/sql/gen/gen_test.go

@@ -34,7 +34,7 @@ func TestCacheModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(sqlFile, true)
+	err = g.StartFromDDL(sqlFile, true, "go_zero")
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
@@ -45,7 +45,7 @@ func TestCacheModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(sqlFile, false)
+	err = g.StartFromDDL(sqlFile, false, "go_zero")
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
@@ -72,7 +72,7 @@ func TestNamingModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(sqlFile, true)
+	err = g.StartFromDDL(sqlFile, true, "go_zero")
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
@@ -83,7 +83,7 @@ func TestNamingModel(t *testing.T) {
 	})
 	assert.Nil(t, err)
 
-	err = g.StartFromDDL(sqlFile, true)
+	err = g.StartFromDDL(sqlFile, true, "go_zero")
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))

+ 6 - 6
tools/goctl/model/sql/gen/keys.go

@@ -39,9 +39,9 @@ type Join []string
 func genCacheKeys(table parser.Table) (Key, []Key) {
 	var primaryKey Key
 	var uniqueKey []Key
-	primaryKey = genCacheKey(table.Name, []*parser.Field{&table.PrimaryKey.Field})
+	primaryKey = genCacheKey(table.Db, table.Name, []*parser.Field{&table.PrimaryKey.Field})
 	for _, each := range table.UniqueIndex {
-		uniqueKey = append(uniqueKey, genCacheKey(table.Name, each))
+		uniqueKey = append(uniqueKey, genCacheKey(table.Db, table.Name, each))
 	}
 	sort.Slice(uniqueKey, func(i, j int) bool {
 		return uniqueKey[i].VarLeft < uniqueKey[j].VarLeft
@@ -50,7 +50,7 @@ func genCacheKeys(table parser.Table) (Key, []Key) {
 	return primaryKey, uniqueKey
 }
 
-func genCacheKey(table stringx.String, in []*parser.Field) Key {
+func genCacheKey(db stringx.String, table stringx.String, in []*parser.Field) Key {
 	var (
 		varLeftJoin, varRightJon, fieldNameJoin Join
 		varLeft, varRight, varExpression        string
@@ -59,9 +59,9 @@ func genCacheKey(table stringx.String, in []*parser.Field) Key {
 		keyLeft, keyRight, dataKeyRight, keyExpression, dataKeyExpression string
 	)
 
-	varLeftJoin = append(varLeftJoin, "cache", table.Source())
-	varRightJon = append(varRightJon, "cache", table.Source())
-	keyLeftJoin = append(keyLeftJoin, table.Source())
+	varLeftJoin = append(varLeftJoin, "cache", db.Source(), table.Source())
+	varRightJon = append(varRightJon, "cache", db.Source(), table.Source())
+	keyLeftJoin = append(keyLeftJoin, db.Source(), table.Source())
 
 	for _, each := range in {
 		varLeftJoin = append(varLeftJoin, each.Name.Source())

+ 25 - 24
tools/goctl/model/sql/gen/keys_test.go

@@ -36,6 +36,7 @@ func TestGenCacheKeys(t *testing.T) {
 	}
 	primariCacheKey, uniqueCacheKey := genCacheKeys(parser.Table{
 		Name: stringx.From("user"),
+		Db:   stringx.From("go_zero"),
 		PrimaryKey: parser.Primary{
 			Field:         *primaryField,
 			AutoIncrement: true,
@@ -70,14 +71,14 @@ func TestGenCacheKeys(t *testing.T) {
 	t.Run("primaryCacheKey", func(t *testing.T) {
 		assert.Equal(t, true, func() bool {
 			return cacheKeyEqual(primariCacheKey, Key{
-				VarLeft:           "cacheUserIdPrefix",
-				VarRight:          `"cache:user:id:"`,
-				VarExpression:     `cacheUserIdPrefix = "cache:user:id:"`,
-				KeyLeft:           "userIdKey",
-				KeyRight:          `fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`,
-				DataKeyRight:      `fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`,
-				KeyExpression:     `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`,
-				DataKeyExpression: `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`,
+				VarLeft:           "cacheGoZeroUserIdPrefix",
+				VarRight:          `"cache:goZero:user:id:"`,
+				VarExpression:     `cacheGoZeroUserIdPrefix = "cache:goZero:user:id:"`,
+				KeyLeft:           "goZeroUserIdKey",
+				KeyRight:          `fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, id)`,
+				DataKeyRight:      `fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, data.Id)`,
+				KeyExpression:     `goZeroUserIdKey := fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, id)`,
+				DataKeyExpression: `goZeroUserIdKey := fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, data.Id)`,
 				FieldNameJoin:     []string{"id"},
 			})
 		}())
@@ -87,25 +88,25 @@ func TestGenCacheKeys(t *testing.T) {
 		assert.Equal(t, true, func() bool {
 			expected := []Key{
 				{
-					VarLeft:           "cacheUserClassNamePrefix",
-					VarRight:          `"cache:user:class:name:"`,
-					VarExpression:     `cacheUserClassNamePrefix = "cache:user:class:name:"`,
-					KeyLeft:           "userClassNameKey",
-					KeyRight:          `fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, class, name)`,
-					DataKeyRight:      `fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, data.Class, data.Name)`,
-					KeyExpression:     `userClassNameKey := fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, class, name)`,
-					DataKeyExpression: `userClassNameKey := fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, data.Class, data.Name)`,
+					VarLeft:           "cacheGoZeroUserClassNamePrefix",
+					VarRight:          `"cache:goZero:user:class:name:"`,
+					VarExpression:     `cacheGoZeroUserClassNamePrefix = "cache:goZero:user:class:name:"`,
+					KeyLeft:           "goZeroUserClassNameKey",
+					KeyRight:          `fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, class, name)`,
+					DataKeyRight:      `fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, data.Class, data.Name)`,
+					KeyExpression:     `goZeroUserClassNameKey := fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, class, name)`,
+					DataKeyExpression: `goZeroUserClassNameKey := fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, data.Class, data.Name)`,
 					FieldNameJoin:     []string{"class", "name"},
 				},
 				{
-					VarLeft:           "cacheUserMobilePrefix",
-					VarRight:          `"cache:user:mobile:"`,
-					VarExpression:     `cacheUserMobilePrefix = "cache:user:mobile:"`,
-					KeyLeft:           "userMobileKey",
-					KeyRight:          `fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`,
-					DataKeyRight:      `fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`,
-					KeyExpression:     `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`,
-					DataKeyExpression: `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`,
+					VarLeft:           "cacheGoZeroUserMobilePrefix",
+					VarRight:          `"cache:goZero:user:mobile:"`,
+					VarExpression:     `cacheGoZeroUserMobilePrefix = "cache:goZero:user:mobile:"`,
+					KeyLeft:           "goZeroUserMobileKey",
+					KeyRight:          `fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, mobile)`,
+					DataKeyRight:      `fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, data.Mobile)`,
+					KeyExpression:     `goZeroUserMobileKey := fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, mobile)`,
+					DataKeyExpression: `goZeroUserMobileKey := fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, data.Mobile)`,
 					FieldNameJoin:     []string{"mobile"},
 				},
 			}

+ 4 - 1
tools/goctl/model/sql/parser/parser.go

@@ -21,6 +21,7 @@ type (
 	// Table describes a mysql table
 	Table struct {
 		Name        stringx.String
+		Db          stringx.String
 		PrimaryKey  Primary
 		UniqueIndex map[string][]*Field
 		Fields      []*Field
@@ -46,7 +47,7 @@ type (
 )
 
 // Parse parses ddl into golang structure
-func Parse(filename string) ([]*Table, error) {
+func Parse(filename string, database string) ([]*Table, error) {
 	p := parser.NewParser()
 	tables, err := p.From(filename)
 	if err != nil {
@@ -145,6 +146,7 @@ func Parse(filename string) ([]*Table, error) {
 
 		list = append(list, &Table{
 			Name:        stringx.From(e.Name),
+			Db: 		 stringx.From(database),
 			PrimaryKey:  primaryKey,
 			UniqueIndex: uniqueIndex,
 			Fields:      fields,
@@ -243,6 +245,7 @@ func ConvertDataType(table *model.Table) (*Table, error) {
 	var reply Table
 	reply.UniqueIndex = map[string][]*Field{}
 	reply.Name = stringx.From(table.Table)
+	reply.Db = stringx.From(table.Db)
 	seqInIndex := 0
 	if table.PrimaryKey.Index != nil {
 		seqInIndex = table.PrimaryKey.Index.SeqInIndex

+ 3 - 3
tools/goctl/model/sql/parser/parser_test.go

@@ -15,7 +15,7 @@ func TestParsePlainText(t *testing.T) {
 	err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
 	assert.Nil(t, err)
 
-	_, err = Parse(sqlFile)
+	_, err = Parse(sqlFile, "go_zero")
 	assert.NotNil(t, err)
 }
 
@@ -24,7 +24,7 @@ func TestParseSelect(t *testing.T) {
 	err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
 	assert.Nil(t, err)
 
-	tables, err := Parse(sqlFile)
+	tables, err := Parse(sqlFile, "go_zero")
 	assert.Nil(t, err)
 	assert.Equal(t, 0, len(tables))
 }
@@ -34,7 +34,7 @@ func TestParseCreateTable(t *testing.T) {
 	err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n  `id` bigint NOT NULL AUTO_INCREMENT,\n  `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机  号',\n  `class` bigint NOT NULL comment '班级',\n  `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n  名',\n  `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n  `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n  PRIMARY KEY (`id`),\n  UNIQUE KEY `mobile_unique` (`mobile`),\n  UNIQUE KEY `class_name_unique` (`class`,`name`),\n  KEY `create_index` (`create_time`),\n  KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0o777)
 	assert.Nil(t, err)
 
-	tables, err := Parse(sqlFile)
+	tables, err := Parse(sqlFile, "go_zero")
 	assert.Equal(t, 1, len(tables))
 	table := tables[0]
 	assert.Nil(t, err)