Browse Source

Optimize model naming (#910)

* add unit test

* fix #907

* format code

* format code

* format code

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
anqiansong 3 years ago
parent
commit
b2fea65faa

+ 2 - 2
tools/goctl/model/sql/command/command.go

@@ -50,7 +50,7 @@ func MysqlDDL(ctx *cli.Context) error {
 		return err
 	}
 
-	return fromDDl(src, dir, cfg, cache, idea, database)
+	return fromDDL(src, dir, cfg, cache, idea, database)
 }
 
 // MySqlDataSource generates model code from datasource
@@ -102,7 +102,7 @@ func PostgreSqlDataSource(ctx *cli.Context) error {
 	return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea)
 }
 
-func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database string) error {
+func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database string) error {
 	log := console.NewConsole(idea)
 	src = strings.TrimSpace(src)
 	if len(src) == 0 {

+ 16 - 7
tools/goctl/model/sql/command/command_test.go

@@ -24,12 +24,12 @@ func TestFromDDl(t *testing.T) {
 	err := gen.Clean()
 	assert.Nil(t, err)
 
-	err = fromDDl("./user.sql", t.TempDir(), cfg, true, false, "go_zero")
+	err = fromDDL("./user.sql", t.TempDir(), cfg, true, false, "go_zero")
 	assert.Equal(t, errNotMatched, err)
 
 	// case dir is not exists
 	unknownDir := filepath.Join(t.TempDir(), "test", "user.sql")
-	err = fromDDl(unknownDir, t.TempDir(), cfg, true, false, "go_zero")
+	err = fromDDL(unknownDir, t.TempDir(), cfg, true, false, "go_zero")
 	assert.True(t, func() bool {
 		switch err.(type) {
 		case *os.PathError:
@@ -40,7 +40,7 @@ func TestFromDDl(t *testing.T) {
 	}())
 
 	// case empty src
-	err = fromDDl("", t.TempDir(), cfg, true, false, "go_zero")
+	err = fromDDL("", t.TempDir(), cfg, true, false, "go_zero")
 	if err != nil {
 		assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
 	}
@@ -70,9 +70,18 @@ func TestFromDDl(t *testing.T) {
 	_, err = os.Stat(user2Sql)
 	assert.Nil(t, err)
 
-	err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, "go_zero")
-	assert.Nil(t, err)
+	filename := filepath.Join(tempDir, "usermodel.go")
+	fromDDL := func(db string) {
+		err = fromDDL(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, db)
+		assert.Nil(t, err)
 
-	_, err = os.Stat(filepath.Join(tempDir, "usermodel.go"))
-	assert.Nil(t, err)
+		_, err = os.Stat(filename)
+		assert.Nil(t, err)
+	}
+
+	fromDDL("go_zero")
+	_ = os.Remove(filename)
+	fromDDL("go-zero")
+	_ = os.Remove(filename)
+	fromDDL("1gozero")
 }

+ 4 - 0
tools/goctl/model/sql/example/makefile

@@ -5,6 +5,10 @@ fromDDLWithCache:
 	goctl template clean
 	goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache" -cache
 
+fromDDLWithCacheAndDb:
+	goctl template clean
+	goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache_db" -database="1gozero" -cache
+
 fromDDLWithoutCache:
 	goctl template clean;
 	goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/nocache"

+ 1 - 1
tools/goctl/model/sql/gen/gen.go

@@ -146,7 +146,7 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
 			return err
 		}
 
-		name := modelFilename + ".go"
+		name := util.SafeString(modelFilename) + ".go"
 		filename := filepath.Join(dirAbs, name)
 		if util.FileExists(filename) {
 			g.Warning("%s already exists, ignored.", name)

+ 13 - 5
tools/goctl/model/sql/gen/keys.go

@@ -6,6 +6,7 @@ import (
 	"strings"
 
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
+	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
@@ -59,9 +60,16 @@ func genCacheKey(db, table stringx.String, in []*parser.Field) Key {
 		keyLeft, keyRight, dataKeyRight, keyExpression, dataKeyExpression string
 	)
 
-	varLeftJoin = append(varLeftJoin, "cache", db.Source(), table.Source())
-	varRightJon = append(varRightJon, "cache", db.Source(), table.Source())
-	keyLeftJoin = append(keyLeftJoin, db.Source(), table.Source())
+	dbName, tableName := util.SafeString(db.Source()), util.SafeString(table.Source())
+	if len(dbName) > 0 {
+		varLeftJoin = append(varLeftJoin, "cache", dbName, tableName)
+		varRightJon = append(varRightJon, "cache", dbName, tableName)
+		keyLeftJoin = append(keyLeftJoin, dbName, tableName)
+	} else {
+		varLeftJoin = append(varLeftJoin, "cache", tableName)
+		varRightJon = append(varRightJon, "cache", tableName)
+		keyLeftJoin = append(keyLeftJoin, tableName)
+	}
 
 	for _, each := range in {
 		varLeftJoin = append(varLeftJoin, each.Name.Source())
@@ -75,11 +83,11 @@ func genCacheKey(db, table stringx.String, in []*parser.Field) Key {
 	varLeftJoin = append(varLeftJoin, "prefix")
 	keyLeftJoin = append(keyLeftJoin, "key")
 
-	varLeft = varLeftJoin.Camel().With("").Untitle()
+	varLeft = util.SafeString(varLeftJoin.Camel().With("").Untitle())
 	varRight = fmt.Sprintf(`"%s"`, varRightJon.Camel().Untitle().With(":").Source()+":")
 	varExpression = fmt.Sprintf(`%s = %s`, varLeft, varRight)
 
-	keyLeft = keyLeftJoin.Camel().With("").Untitle()
+	keyLeft = util.SafeString(keyLeftJoin.Camel().With("").Untitle())
 	keyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, keyRightJoin.With(", ").Source())
 	dataKeyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, dataRightJoin.With(", ").Source())
 	keyExpression = fmt.Sprintf("%s := %s", keyLeft, keyRight)

+ 96 - 0
tools/goctl/util/env/env_test.go

@@ -0,0 +1,96 @@
+package env
+
+import (
+	"bytes"
+	"fmt"
+	"os/exec"
+	"runtime"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/vars"
+)
+
+func TestLookUpGo(t *testing.T) {
+	xGo, err := LookUpGo()
+	if err != nil {
+		return
+	}
+
+	assert.True(t, util.FileExists(xGo))
+	output, errOutput, err := execCommand(xGo, "version")
+	if err != nil {
+		return
+	}
+
+	if len(errOutput) > 0 {
+		return
+	}
+	assert.Equal(t, wrapVersion(), output)
+}
+
+func TestLookUpProtoc(t *testing.T) {
+	xProtoc, err := LookUpProtoc()
+	if err != nil {
+		return
+	}
+
+	assert.True(t, util.FileExists(xProtoc))
+	output, errOutput, err := execCommand(xProtoc, "--version")
+	if err != nil {
+		return
+	}
+
+	if len(errOutput) > 0 {
+		return
+	}
+	assert.True(t, len(output) > 0)
+}
+
+func TestLookUpProtocGenGo(t *testing.T) {
+	xProtocGenGo, err := LookUpProtocGenGo()
+	if err != nil {
+		return
+	}
+	assert.True(t, util.FileExists(xProtocGenGo))
+}
+
+func TestLookPath(t *testing.T) {
+	xGo, err := LookPath("go")
+	if err != nil {
+		return
+	}
+	assert.True(t, util.FileExists(xGo))
+}
+
+func TestCanExec(t *testing.T) {
+	canExec := runtime.GOOS != vars.OsJs && runtime.GOOS != vars.OsIOS
+	assert.Equal(t, canExec, CanExec())
+}
+
+func execCommand(cmd string, arg ...string) (stdout string, stderr string, err error) {
+	output := bytes.NewBuffer(nil)
+	errOutput := bytes.NewBuffer(nil)
+	c := exec.Command(cmd, arg...)
+	c.Stdout = output
+	c.Stderr = errOutput
+	err = c.Run()
+	if err != nil {
+		return
+	}
+	if errOutput.Len() > 0 {
+		stderr = errOutput.String()
+		return
+	}
+	stdout = strings.TrimSpace(output.String())
+	return
+}
+
+func wrapVersion() string {
+	version := runtime.Version()
+	os := runtime.GOOS
+	arch := runtime.GOARCH
+	return fmt.Sprintf("go version %s %s/%s", version, os, arch)
+}

+ 32 - 0
tools/goctl/util/string.go

@@ -32,3 +32,35 @@ func Index(slice []string, item string) int {
 
 	return -1
 }
+
+// SafeString converts the input string into a safe naming style in golang
+func SafeString(in string) string {
+	if len(in) == 0 {
+		return in
+	}
+
+	data := strings.Map(func(r rune) rune {
+		if isSafeRune(r) {
+			return r
+		}
+		return '_'
+	}, in)
+
+	headRune := rune(data[0])
+	if isNumber(headRune) {
+		return "_" + data
+	}
+	return data
+}
+
+func isSafeRune(r rune) bool {
+	return isLetter(r) || isNumber(r) || r == '_'
+}
+
+func isLetter(r rune) bool {
+	return 'A' <= r && r <= 'z'
+}
+
+func isNumber(r rune) bool {
+	return '0' <= r && r <= '9'
+}

+ 66 - 0
tools/goctl/util/string_test.go

@@ -0,0 +1,66 @@
+package util
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+type data struct {
+	input    string
+	expected string
+}
+
+func TestTitle(t *testing.T) {
+	list := []*data{
+		{input: "_", expected: "_"},
+		{input: "abc", expected: "Abc"},
+		{input: "ABC", expected: "ABC"},
+		{input: "", expected: ""},
+		{input: " abc", expected: " abc"},
+	}
+	for _, e := range list {
+		assert.Equal(t, e.expected, Title(e.input))
+	}
+}
+
+func TestUntitle(t *testing.T) {
+	list := []*data{
+		{input: "_", expected: "_"},
+		{input: "Abc", expected: "abc"},
+		{input: "ABC", expected: "aBC"},
+		{input: "", expected: ""},
+		{input: " abc", expected: " abc"},
+	}
+
+	for _, e := range list {
+		assert.Equal(t, e.expected, Untitle(e.input))
+	}
+}
+
+func TestIndex(t *testing.T) {
+	list := []string{"a", "b", "c"}
+	assert.Equal(t, 1, Index(list, "b"))
+	assert.Equal(t, -1, Index(list, "d"))
+}
+
+func TestSafeString(t *testing.T) {
+	list := []*data{
+		{input: "_", expected: "_"},
+		{input: "a-b-c", expected: "a_b_c"},
+		{input: "123abc", expected: "_123abc"},
+		{input: "汉abc", expected: "_abc"},
+		{input: "汉a字", expected: "_a_"},
+		{input: "キャラクターabc", expected: "______abc"},
+		{input: "-a_B-C", expected: "_a_B_C"},
+		{input: "a_B C", expected: "a_B_C"},
+		{input: "A#B#C", expected: "A_B_C"},
+		{input: "_123", expected: "_123"},
+		{input: "", expected: ""},
+		{input: "\t", expected: "_"},
+		{input: "\n", expected: "_"},
+	}
+	for _, e := range list {
+		assert.Equal(t, e.expected, SafeString(e.input))
+	}
+}