Browse Source

patch model&rpc (#207)

* change column to read from information_schema

* reactor generate mode from datasource

* reactor generate mode from datasource

* add primary key check logic

* resolve rebase conflicts

* add naming style

* add filename test case

* resolve rebase conflicts

* reactor test

* add test case

* change shell script to makefile

* update rpc new

* update gen_test.go

* format code

* format code

* update test

* generates alias
Keson 4 years ago
parent
commit
24fb29a356
55 changed files with 678 additions and 1167 deletions
  1. 9 1
      tools/goctl/goctl.go
  2. 0 27
      tools/goctl/model/sql/builderx/builder.go
  3. 38 26
      tools/goctl/model/sql/command/command.go
  4. 75 0
      tools/goctl/model/sql/command/command_test.go
  5. 0 11
      tools/goctl/model/sql/example/generator.sh
  6. 15 0
      tools/goctl/model/sql/example/makefile
  7. 1 0
      tools/goctl/model/sql/gen/delete.go
  8. 2 0
      tools/goctl/model/sql/gen/field.go
  9. 1 0
      tools/goctl/model/sql/gen/findone.go
  10. 57 12
      tools/goctl/model/sql/gen/gen.go
  11. 17 9
      tools/goctl/model/sql/gen/gen_test.go
  12. 0 5
      tools/goctl/model/sql/gen/keys_test.go
  13. 2 3
      tools/goctl/model/sql/gen/split.go
  14. 1 0
      tools/goctl/model/sql/gen/tag.go
  15. 1 0
      tools/goctl/model/sql/model/ddlmodel.go
  16. 15 0
      tools/goctl/model/sql/model/informationschemamodel.go
  17. 61 2
      tools/goctl/model/sql/parser/parser.go
  18. 56 0
      tools/goctl/model/sql/parser/parser_test.go
  19. 15 9
      tools/goctl/rpc/cli/cli.go
  20. 0 75
      tools/goctl/rpc/generator/base/common.pb.go
  21. 9 2
      tools/goctl/rpc/generator/filename.go
  22. 17 0
      tools/goctl/rpc/generator/filename_test.go
  23. 15 13
      tools/goctl/rpc/generator/gen.go
  24. 50 104
      tools/goctl/rpc/generator/gen_test.go
  25. 5 6
      tools/goctl/rpc/generator/gencall.go
  26. 0 44
      tools/goctl/rpc/generator/gencall_test.go
  27. 2 2
      tools/goctl/rpc/generator/genconfig.go
  28. 0 48
      tools/goctl/rpc/generator/genconfig_test.go
  29. 8 8
      tools/goctl/rpc/generator/generator.go
  30. 5 3
      tools/goctl/rpc/generator/genetc.go
  31. 0 45
      tools/goctl/rpc/generator/genetc_test.go
  32. 2 2
      tools/goctl/rpc/generator/genlogic.go
  33. 0 44
      tools/goctl/rpc/generator/genlogic_test.go
  34. 3 3
      tools/goctl/rpc/generator/genmain.go
  35. 0 45
      tools/goctl/rpc/generator/genmain_test.go
  36. 1 1
      tools/goctl/rpc/generator/genpb.go
  37. 0 184
      tools/goctl/rpc/generator/genpb_test.go
  38. 2 2
      tools/goctl/rpc/generator/genserver.go
  39. 0 45
      tools/goctl/rpc/generator/genserver_test.go
  40. 2 2
      tools/goctl/rpc/generator/gensvc.go
  41. 0 40
      tools/goctl/rpc/generator/gensvc_test.go
  42. 0 130
      tools/goctl/rpc/generator/mkdir_test.go
  43. 24 0
      tools/goctl/rpc/generator/naming.go
  44. 25 0
      tools/goctl/rpc/generator/naming_test.go
  45. 7 8
      tools/goctl/rpc/generator/prototmpl_test.go
  46. 83 65
      tools/goctl/rpc/generator/template_test.go
  47. 50 13
      tools/goctl/rpc/generator/test.proto
  48. 0 12
      tools/goctl/rpc/generator/test_base.proto
  49. 0 18
      tools/goctl/rpc/generator/test_go_option.proto
  50. 0 18
      tools/goctl/rpc/generator/test_import.proto
  51. 0 18
      tools/goctl/rpc/generator/test_option.proto
  52. 0 27
      tools/goctl/rpc/generator/test_service_rpc_naming_snake.proto
  53. 0 17
      tools/goctl/rpc/generator/test_stream.proto
  54. 0 18
      tools/goctl/rpc/generator/test_word_option.proto
  55. 2 0
      tools/goctl/util/stringx/string.go

+ 9 - 1
tools/goctl/goctl.go

@@ -201,6 +201,10 @@ var (
 					Name:  "new",
 					Name:  "new",
 					Usage: `generate rpc demo service`,
 					Usage: `generate rpc demo service`,
 					Flags: []cli.Flag{
 					Flags: []cli.Flag{
+						cli.StringFlag{
+							Name:  "style",
+							Usage: "the file naming style, lower|camel|snake,default is lower",
+						},
 						cli.BoolFlag{
 						cli.BoolFlag{
 							Name:  "idea",
 							Name:  "idea",
 							Usage: "whether the command execution environment is from idea plugin. [optional]",
 							Usage: "whether the command execution environment is from idea plugin. [optional]",
@@ -235,6 +239,10 @@ var (
 							Name:  "dir, d",
 							Name:  "dir, d",
 							Usage: `the target path of the code`,
 							Usage: `the target path of the code`,
 						},
 						},
+						cli.StringFlag{
+							Name:  "style",
+							Usage: "the file naming style, lower|camel|snake,default is lower",
+						},
 						cli.BoolFlag{
 						cli.BoolFlag{
 							Name:  "idea",
 							Name:  "idea",
 							Usage: "whether the command execution environment is from idea plugin. [optional]",
 							Usage: "whether the command execution environment is from idea plugin. [optional]",
@@ -266,7 +274,7 @@ var (
 								},
 								},
 								cli.StringFlag{
 								cli.StringFlag{
 									Name:  "style",
 									Name:  "style",
-									Usage: "the file naming style, lower|camel|underline,default is lower",
+									Usage: "the file naming style, lower|camel|snake,default is lower",
 								},
 								},
 								cli.BoolFlag{
 								cli.BoolFlag{
 									Name:  "cache, c",
 									Name:  "cache, c",

+ 0 - 27
tools/goctl/model/sql/builderx/builder.go

@@ -68,30 +68,3 @@ func FieldNames(in interface{}) []string {
 	}
 	}
 	return out
 	return out
 }
 }
-func FieldNamesAlias(in interface{}, alias string) []string {
-	out := make([]string, 0)
-	v := reflect.ValueOf(in)
-	if v.Kind() == reflect.Ptr {
-		v = v.Elem()
-	}
-	// we only accept structs
-	if v.Kind() != reflect.Struct {
-		panic(fmt.Errorf("ToMap only accepts structs; got %T", v))
-	}
-	typ := v.Type()
-	for i := 0; i < v.NumField(); i++ {
-		// gets us a StructField
-		fi := typ.Field(i)
-		tagName := ""
-		if tagv := fi.Tag.Get(dbTag); tagv != "" {
-			tagName = tagv
-		} else {
-			tagName = fi.Name
-		}
-		if len(alias) > 0 {
-			tagName = alias + "." + tagName
-		}
-		out = append(out, tagName)
-	}
-	return out
-}

+ 38 - 26
tools/goctl/model/sql/command/command.go

@@ -17,6 +17,8 @@ import (
 	"github.com/urfave/cli"
 	"github.com/urfave/cli"
 )
 )
 
 
+var errNotMatched = errors.New("sql not matched")
+
 const (
 const (
 	flagSrc   = "src"
 	flagSrc   = "src"
 	flagDir   = "dir"
 	flagDir   = "dir"
@@ -33,6 +35,20 @@ func MysqlDDL(ctx *cli.Context) error {
 	cache := ctx.Bool(flagCache)
 	cache := ctx.Bool(flagCache)
 	idea := ctx.Bool(flagIdea)
 	idea := ctx.Bool(flagIdea)
 	namingStyle := strings.TrimSpace(ctx.String(flagStyle))
 	namingStyle := strings.TrimSpace(ctx.String(flagStyle))
+	return fromDDl(src, dir, namingStyle, cache, idea)
+}
+
+func MyDataSource(ctx *cli.Context) error {
+	url := strings.TrimSpace(ctx.String(flagUrl))
+	dir := strings.TrimSpace(ctx.String(flagDir))
+	cache := ctx.Bool(flagCache)
+	idea := ctx.Bool(flagIdea)
+	namingStyle := strings.TrimSpace(ctx.String(flagStyle))
+	pattern := strings.TrimSpace(ctx.String(flagTable))
+	return fromDataSource(url, pattern, dir, namingStyle, cache, idea)
+}
+
+func fromDDl(src, dir, namingStyle string, cache, idea bool) error {
 	log := console.NewConsole(idea)
 	log := console.NewConsole(idea)
 	src = strings.TrimSpace(src)
 	src = strings.TrimSpace(src)
 	if len(src) == 0 {
 	if len(src) == 0 {
@@ -52,29 +68,29 @@ func MysqlDDL(ctx *cli.Context) error {
 		return err
 		return err
 	}
 	}
 
 
+	if len(files) == 0 {
+		return errNotMatched
+	}
+
 	var source []string
 	var source []string
 	for _, file := range files {
 	for _, file := range files {
 		data, err := ioutil.ReadFile(file)
 		data, err := ioutil.ReadFile(file)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
+
 		source = append(source, string(data))
 		source = append(source, string(data))
 	}
 	}
-	generator := gen.NewDefaultGenerator(strings.Join(source, "\n"), dir, namingStyle, gen.WithConsoleOption(log))
-	err = generator.Start(cache)
+	generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log))
 	if err != nil {
 	if err != nil {
-		log.Error("%v", err)
+		return err
 	}
 	}
-	return nil
+
+	err = generator.StartFromDDL(strings.Join(source, "\n"), cache)
+	return err
 }
 }
 
 
-func MyDataSource(ctx *cli.Context) error {
-	url := strings.TrimSpace(ctx.String(flagUrl))
-	dir := strings.TrimSpace(ctx.String(flagDir))
-	cache := ctx.Bool(flagCache)
-	idea := ctx.Bool(flagIdea)
-	namingStyle := strings.TrimSpace(ctx.String(flagStyle))
-	pattern := strings.TrimSpace(ctx.String(flagTable))
+func fromDataSource(url, pattern, dir, namingStyle string, cache, idea bool) error {
 	log := console.NewConsole(idea)
 	log := console.NewConsole(idea)
 	if len(url) == 0 {
 	if len(url) == 0 {
 		log.Error("%v", "expected data source of mysql, but nothing found")
 		log.Error("%v", "expected data source of mysql, but nothing found")
@@ -100,10 +116,8 @@ func MyDataSource(ctx *cli.Context) error {
 	}
 	}
 
 
 	logx.Disable()
 	logx.Disable()
-	conn := sqlx.NewMysql(url)
 	databaseSource := strings.TrimSuffix(url, "/"+cfg.DBName) + "/information_schema"
 	databaseSource := strings.TrimSuffix(url, "/"+cfg.DBName) + "/information_schema"
 	db := sqlx.NewMysql(databaseSource)
 	db := sqlx.NewMysql(databaseSource)
-	m := model.NewDDLModel(conn)
 	im := model.NewInformationSchemaModel(db)
 	im := model.NewInformationSchemaModel(db)
 
 
 	tables, err := im.GetAllTables(cfg.DBName)
 	tables, err := im.GetAllTables(cfg.DBName)
@@ -111,7 +125,7 @@ func MyDataSource(ctx *cli.Context) error {
 		return err
 		return err
 	}
 	}
 
 
-	var matchTables []string
+	matchTables := make(map[string][]*model.Column)
 	for _, item := range tables {
 	for _, item := range tables {
 		match, err := filepath.Match(pattern, item)
 		match, err := filepath.Match(pattern, item)
 		if err != nil {
 		if err != nil {
@@ -121,24 +135,22 @@ func MyDataSource(ctx *cli.Context) error {
 		if !match {
 		if !match {
 			continue
 			continue
 		}
 		}
-
-		matchTables = append(matchTables, item)
+		columns, err := im.FindByTableName(cfg.DBName, item)
+		if err != nil {
+			return err
+		}
+		matchTables[item] = columns
 	}
 	}
+
 	if len(matchTables) == 0 {
 	if len(matchTables) == 0 {
 		return errors.New("no tables matched")
 		return errors.New("no tables matched")
 	}
 	}
 
 
-	ddl, err := m.ShowDDL(matchTables...)
-	if err != nil {
-		log.Error("%v", err)
-		return nil
-	}
-
-	generator := gen.NewDefaultGenerator(strings.Join(ddl, "\n"), dir, namingStyle, gen.WithConsoleOption(log))
-	err = generator.Start(cache)
+	generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log))
 	if err != nil {
 	if err != nil {
-		log.Error("%v", err)
+		return err
 	}
 	}
 
 
-	return nil
+	err = generator.StartFromInformationSchema(cfg.DBName, matchTables, cache)
+	return err
 }
 }

+ 75 - 0
tools/goctl/model/sql/command/command_test.go

@@ -0,0 +1,75 @@
+package command
+
+import (
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
+	"github.com/tal-tech/go-zero/tools/goctl/util"
+)
+
+var sql = "-- 用户表 --\nCREATE TABLE `user` (\n  `id` bigint(10) NOT NULL AUTO_INCREMENT,\n  `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n  `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n  `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n  `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n  `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n  `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n  `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n  PRIMARY KEY (`id`),\n  UNIQUE KEY `name_index` (`name`),\n  UNIQUE KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;\n\n"
+
+func TestFromDDl(t *testing.T) {
+	err := fromDDl("./user.sql", t.TempDir(), gen.NamingCamel, true, false)
+	assert.Equal(t, errNotMatched, err)
+
+	// case dir is not exists
+	unknownDir := filepath.Join(t.TempDir(), "test", "user.sql")
+	err = fromDDl(unknownDir, t.TempDir(), gen.NamingCamel, true, false)
+	assert.True(t, func() bool {
+		switch err.(type) {
+		case *os.PathError:
+			return true
+		default:
+			return false
+		}
+	}())
+
+	// case empty src
+	err = fromDDl("", t.TempDir(), gen.NamingCamel, true, false)
+	if err != nil {
+		assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
+	}
+
+	// case unknown naming style
+	tmp := filepath.Join(t.TempDir(), "user.sql")
+	err = fromDDl(tmp, t.TempDir(), "lower1", true, false)
+	if err != nil {
+		assert.Equal(t, "unexpected naming style: lower1", err.Error())
+	}
+
+	tempDir := filepath.Join(t.TempDir(), "test")
+	err = util.MkdirIfNotExist(tempDir)
+	if err != nil {
+		return
+	}
+
+	user1Sql := filepath.Join(tempDir, "user1.sql")
+	user2Sql := filepath.Join(tempDir, "user2.sql")
+
+	err = ioutil.WriteFile(user1Sql, []byte(sql), os.ModePerm)
+	if err != nil {
+		return
+	}
+
+	err = ioutil.WriteFile(user2Sql, []byte(sql), os.ModePerm)
+	if err != nil {
+		return
+	}
+
+	_, err = os.Stat(user1Sql)
+	assert.Nil(t, err)
+
+	_, err = os.Stat(user2Sql)
+	assert.Nil(t, err)
+
+	err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, gen.NamingLower, true, false)
+	assert.Nil(t, err)
+
+	_, err = os.Stat(filepath.Join(tempDir, "usermodel.go"))
+	assert.Nil(t, err)
+}

+ 0 - 11
tools/goctl/model/sql/example/generator.sh

@@ -1,11 +0,0 @@
-#!/bin/bash
-
-# generate model with cache from ddl
-goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -c
-
-# generate model with cache from data source
-#user=root
-#password=password
-#datasource=127.0.0.1:3306
-#database=test
-#goctl model mysql datasource -url="${user}:${password}@tcp(${datasource})/${database}" -table="*" -dir ./model

+ 15 - 0
tools/goctl/model/sql/example/makefile

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+# generate model with cache from ddl
+fromDDL:
+	goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -c
+
+
+# generate model with cache from data source
+user=root
+password=password
+datasource=127.0.0.1:3306
+database=gozero
+
+fromDataSource:
+	goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style camel

+ 1 - 0
tools/goctl/model/sql/gen/delete.go

@@ -42,5 +42,6 @@ func genDelete(table Table, withCache bool) (string, error) {
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
+
 	return output.String(), nil
 	return output.String(), nil
 }
 }

+ 2 - 0
tools/goctl/model/sql/gen/field.go

@@ -15,6 +15,7 @@ func genFields(fields []parser.Field) (string, error) {
 		if err != nil {
 		if err != nil {
 			return "", err
 			return "", err
 		}
 		}
+
 		list = append(list, result)
 		list = append(list, result)
 	}
 	}
 	return strings.Join(list, "\n"), nil
 	return strings.Join(list, "\n"), nil
@@ -43,5 +44,6 @@ func genField(field parser.Field) (string, error) {
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
+
 	return output.String(), nil
 	return output.String(), nil
 }
 }

+ 1 - 0
tools/goctl/model/sql/gen/findone.go

@@ -28,5 +28,6 @@ func genFindOne(table Table, withCache bool) (string, error) {
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
+
 	return output.String(), nil
 	return output.String(), nil
 }
 }

+ 57 - 12
tools/goctl/model/sql/gen/gen.go

@@ -7,6 +7,7 @@ import (
 	"path/filepath"
 	"path/filepath"
 	"strings"
 	"strings"
 
 
+	"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
@@ -24,8 +25,8 @@ const (
 
 
 type (
 type (
 	defaultGenerator struct {
 	defaultGenerator struct {
-		source string
-		dir    string
+		//source string
+		dir string
 		console.Console
 		console.Console
 		pkg         string
 		pkg         string
 		namingStyle string
 		namingStyle string
@@ -33,18 +34,30 @@ type (
 	Option func(generator *defaultGenerator)
 	Option func(generator *defaultGenerator)
 )
 )
 
 
-func NewDefaultGenerator(source, dir, namingStyle string, opt ...Option) *defaultGenerator {
+func NewDefaultGenerator(dir, namingStyle string, opt ...Option) (*defaultGenerator, error) {
 	if dir == "" {
 	if dir == "" {
 		dir = pwd
 		dir = pwd
 	}
 	}
-	generator := &defaultGenerator{source: source, dir: dir, namingStyle: namingStyle}
+	dirAbs, err := filepath.Abs(dir)
+	if err != nil {
+		return nil, err
+	}
+
+	dir = dirAbs
+	pkg := filepath.Base(dirAbs)
+	err = util.MkdirIfNotExist(dir)
+	if err != nil {
+		return nil, err
+	}
+
+	generator := &defaultGenerator{dir: dir, namingStyle: namingStyle, pkg: pkg}
 	var optionList []Option
 	var optionList []Option
 	optionList = append(optionList, newDefaultOption())
 	optionList = append(optionList, newDefaultOption())
 	optionList = append(optionList, opt...)
 	optionList = append(optionList, opt...)
 	for _, fn := range optionList {
 	for _, fn := range optionList {
 		fn(generator)
 		fn(generator)
 	}
 	}
-	return generator
+	return generator, nil
 }
 }
 
 
 func WithConsoleOption(c console.Console) Option {
 func WithConsoleOption(c console.Console) Option {
@@ -59,21 +72,45 @@ func newDefaultOption() Option {
 	}
 	}
 }
 }
 
 
-func (g *defaultGenerator) Start(withCache bool) error {
+func (g *defaultGenerator) StartFromDDL(source string, withCache bool) error {
+	modelList, err := g.genFromDDL(source, withCache)
+	if err != nil {
+		return err
+	}
+
+	return g.createFile(modelList)
+}
+
+func (g *defaultGenerator) StartFromInformationSchema(db string, columns map[string][]*model.Column, withCache bool) error {
+	m := make(map[string]string)
+	for tableName, column := range columns {
+		table, err := parser.ConvertColumn(db, tableName, column)
+		if err != nil {
+			return err
+		}
+
+		code, err := g.genModel(*table, withCache)
+		if err != nil {
+			return err
+		}
+
+		m[table.Name.Source()] = code
+	}
+	return g.createFile(m)
+}
+
+func (g *defaultGenerator) createFile(modelList map[string]string) error {
 	dirAbs, err := filepath.Abs(g.dir)
 	dirAbs, err := filepath.Abs(g.dir)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+
 	g.dir = dirAbs
 	g.dir = dirAbs
 	g.pkg = filepath.Base(dirAbs)
 	g.pkg = filepath.Base(dirAbs)
 	err = util.MkdirIfNotExist(dirAbs)
 	err = util.MkdirIfNotExist(dirAbs)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	modelList, err := g.genFromDDL(withCache)
-	if err != nil {
-		return err
-	}
 
 
 	for tableName, code := range modelList {
 	for tableName, code := range modelList {
 		tn := stringx.From(tableName)
 		tn := stringx.From(tableName)
@@ -96,6 +133,9 @@ func (g *defaultGenerator) Start(withCache bool) error {
 	}
 	}
 	// generate error file
 	// generate error file
 	filename := filepath.Join(dirAbs, "vars.go")
 	filename := filepath.Join(dirAbs, "vars.go")
+	if g.namingStyle == NamingCamel {
+		filename = filepath.Join(dirAbs, "Vars.go")
+	}
 	text, err := util.LoadTemplate(category, errTemplateFile, template.Error)
 	text, err := util.LoadTemplate(category, errTemplateFile, template.Error)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -113,8 +153,8 @@ func (g *defaultGenerator) Start(withCache bool) error {
 }
 }
 
 
 // ret1: key-table name,value-code
 // ret1: key-table name,value-code
-func (g *defaultGenerator) genFromDDL(withCache bool) (map[string]string, error) {
-	ddlList := g.split()
+func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string]string, error) {
+	ddlList := g.split(source)
 	m := make(map[string]string)
 	m := make(map[string]string)
 	for _, ddl := range ddlList {
 	for _, ddl := range ddlList {
 		table, err := parser.Parse(ddl)
 		table, err := parser.Parse(ddl)
@@ -139,10 +179,15 @@ type (
 )
 )
 
 
 func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) {
 func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) {
+	if len(in.PrimaryKey.Name.Source()) == 0 {
+		return "", fmt.Errorf("table %s: missing primary key", in.Name.Source())
+	}
+
 	text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
 	text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
+
 	t := util.With("model").
 	t := util.With("model").
 		Parse(text).
 		Parse(text).
 		GoFmt(true)
 		GoFmt(true)

+ 17 - 9
tools/goctl/model/sql/gen/gen_test.go

@@ -22,15 +22,19 @@ func TestCacheModel(t *testing.T) {
 	defer func() {
 	defer func() {
 		_ = os.RemoveAll(dir)
 		_ = os.RemoveAll(dir)
 	}()
 	}()
-	g := NewDefaultGenerator(source, cacheDir, NamingLower)
-	err := g.Start(true)
+	g, err := NewDefaultGenerator(cacheDir, NamingCamel)
+	assert.Nil(t, err)
+
+	err = g.StartFromDDL(source, true)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 	assert.True(t, func() bool {
-		_, err := os.Stat(filepath.Join(cacheDir, "testuserinfomodel.go"))
+		_, err := os.Stat(filepath.Join(cacheDir, "TestUserInfoModel.go"))
 		return err == nil
 		return err == nil
 	}())
 	}())
-	g = NewDefaultGenerator(source, noCacheDir, NamingLower)
-	err = g.Start(false)
+	g, err = NewDefaultGenerator(noCacheDir, NamingLower)
+	assert.Nil(t, err)
+
+	err = g.StartFromDDL(source, false)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go"))
 		_, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go"))
@@ -47,15 +51,19 @@ func TestNamingModel(t *testing.T) {
 	defer func() {
 	defer func() {
 		_ = os.RemoveAll(dir)
 		_ = os.RemoveAll(dir)
 	}()
 	}()
-	g := NewDefaultGenerator(source, camelDir, NamingCamel)
-	err := g.Start(true)
+	g, err := NewDefaultGenerator(camelDir, NamingCamel)
+	assert.Nil(t, err)
+
+	err = g.StartFromDDL(source, true)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
 		_, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
 		return err == nil
 		return err == nil
 	}())
 	}())
-	g = NewDefaultGenerator(source, snakeDir, NamingSnake)
-	err = g.Start(true)
+	g, err = NewDefaultGenerator(snakeDir, NamingSnake)
+	assert.Nil(t, err)
+
+	err = g.StartFromDDL(source, true)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.True(t, func() bool {
 	assert.True(t, func() bool {
 		_, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go"))
 		_, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go"))

+ 0 - 5
tools/goctl/model/sql/gen/keys_test.go

@@ -17,7 +17,6 @@ func TestGenCacheKeys(t *testing.T) {
 				Name:         stringx.From("id"),
 				Name:         stringx.From("id"),
 				DataBaseType: "bigint",
 				DataBaseType: "bigint",
 				DataType:     "int64",
 				DataType:     "int64",
-				IsKey:        false,
 				IsPrimaryKey: true,
 				IsPrimaryKey: true,
 				IsUniqueKey:  false,
 				IsUniqueKey:  false,
 				Comment:      "自增id",
 				Comment:      "自增id",
@@ -29,7 +28,6 @@ func TestGenCacheKeys(t *testing.T) {
 				Name:         stringx.From("mobile"),
 				Name:         stringx.From("mobile"),
 				DataBaseType: "varchar",
 				DataBaseType: "varchar",
 				DataType:     "string",
 				DataType:     "string",
-				IsKey:        false,
 				IsPrimaryKey: false,
 				IsPrimaryKey: false,
 				IsUniqueKey:  true,
 				IsUniqueKey:  true,
 				Comment:      "手机号",
 				Comment:      "手机号",
@@ -38,7 +36,6 @@ func TestGenCacheKeys(t *testing.T) {
 				Name:         stringx.From("name"),
 				Name:         stringx.From("name"),
 				DataBaseType: "varchar",
 				DataBaseType: "varchar",
 				DataType:     "string",
 				DataType:     "string",
-				IsKey:        false,
 				IsPrimaryKey: false,
 				IsPrimaryKey: false,
 				IsUniqueKey:  true,
 				IsUniqueKey:  true,
 				Comment:      "姓名",
 				Comment:      "姓名",
@@ -47,7 +44,6 @@ func TestGenCacheKeys(t *testing.T) {
 				Name:         stringx.From("createTime"),
 				Name:         stringx.From("createTime"),
 				DataBaseType: "timestamp",
 				DataBaseType: "timestamp",
 				DataType:     "time.Time",
 				DataType:     "time.Time",
-				IsKey:        false,
 				IsPrimaryKey: false,
 				IsPrimaryKey: false,
 				IsUniqueKey:  false,
 				IsUniqueKey:  false,
 				Comment:      "创建时间",
 				Comment:      "创建时间",
@@ -56,7 +52,6 @@ func TestGenCacheKeys(t *testing.T) {
 				Name:         stringx.From("updateTime"),
 				Name:         stringx.From("updateTime"),
 				DataBaseType: "timestamp",
 				DataBaseType: "timestamp",
 				DataType:     "time.Time",
 				DataType:     "time.Time",
-				IsKey:        false,
 				IsPrimaryKey: false,
 				IsPrimaryKey: false,
 				IsUniqueKey:  false,
 				IsUniqueKey:  false,
 				Comment:      "更新时间",
 				Comment:      "更新时间",

+ 2 - 3
tools/goctl/model/sql/gen/split.go

@@ -4,11 +4,10 @@ import (
 	"regexp"
 	"regexp"
 )
 )
 
 
-func (g *defaultGenerator) split() []string {
+func (g *defaultGenerator) split(source string) []string {
 	reg := regexp.MustCompile(createTableFlag)
 	reg := regexp.MustCompile(createTableFlag)
-	index := reg.FindAllStringIndex(g.source, -1)
+	index := reg.FindAllStringIndex(source, -1)
 	list := make([]string, 0)
 	list := make([]string, 0)
-	source := g.source
 	for i := len(index) - 1; i >= 0; i-- {
 	for i := len(index) - 1; i >= 0; i-- {
 		subIndex := index[i]
 		subIndex := index[i]
 		if len(subIndex) == 0 {
 		if len(subIndex) == 0 {

+ 1 - 0
tools/goctl/model/sql/gen/tag.go

@@ -22,5 +22,6 @@ func genTag(in string) (string, error) {
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
+
 	return output.String(), nil
 	return output.String(), nil
 }
 }

+ 1 - 0
tools/goctl/model/sql/model/ddlmodel.go

@@ -27,6 +27,7 @@ func (m *DDLModel) ShowDDL(table ...string) ([]string, error) {
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
+
 		ddl = append(ddl, resp.DDL)
 		ddl = append(ddl, resp.DDL)
 	}
 	}
 	return ddl, nil
 	return ddl, nil

+ 15 - 0
tools/goctl/model/sql/model/informationschemamodel.go

@@ -8,6 +8,13 @@ type (
 	InformationSchemaModel struct {
 	InformationSchemaModel struct {
 		conn sqlx.SqlConn
 		conn sqlx.SqlConn
 	}
 	}
+	Column struct {
+		Name     string `db:"COLUMN_NAME"`
+		DataType string `db:"DATA_TYPE"`
+		Key      string `db:"COLUMN_KEY"`
+		Extra    string `db:"EXTRA"`
+		Comment  string `db:"COLUMN_COMMENT"`
+	}
 )
 )
 
 
 func NewInformationSchemaModel(conn sqlx.SqlConn) *InformationSchemaModel {
 func NewInformationSchemaModel(conn sqlx.SqlConn) *InformationSchemaModel {
@@ -21,5 +28,13 @@ func (m *InformationSchemaModel) GetAllTables(database string) ([]string, error)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+
 	return tables, nil
 	return tables, nil
 }
 }
+
+func (m *InformationSchemaModel) FindByTableName(db, table string) ([]*Column, error) {
+	querySql := `select COLUMN_NAME,DATA_TYPE,COLUMN_KEY,EXTRA,COLUMN_COMMENT from COLUMNS where TABLE_SCHEMA = ? and TABLE_NAME = ?`
+	var reply []*Column
+	err := m.conn.QueryRows(&reply, querySql, db, table)
+	return reply, err
+}

+ 61 - 2
tools/goctl/model/sql/parser/parser.go

@@ -2,8 +2,10 @@ package parser
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"strings"
 
 
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
+	"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 	"github.com/xwb1989/sqlparser"
 	"github.com/xwb1989/sqlparser"
 )
 )
@@ -34,7 +36,6 @@ type (
 		Name         stringx.String
 		Name         stringx.String
 		DataBaseType string
 		DataBaseType string
 		DataType     string
 		DataType     string
-		IsKey        bool
 		IsPrimaryKey bool
 		IsPrimaryKey bool
 		IsUniqueKey  bool
 		IsUniqueKey  bool
 		Comment      string
 		Comment      string
@@ -123,7 +124,6 @@ func Parse(ddl string) (*Table, error) {
 		field.Comment = comment
 		field.Comment = comment
 		key, ok := keyMap[column.Name.String()]
 		key, ok := keyMap[column.Name.String()]
 		if ok {
 		if ok {
-			field.IsKey = true
 			field.IsPrimaryKey = key == primary
 			field.IsPrimaryKey = key == primary
 			field.IsUniqueKey = key == unique
 			field.IsUniqueKey = key == unique
 			if field.IsPrimaryKey {
 			if field.IsPrimaryKey {
@@ -151,3 +151,62 @@ func (t *Table) ContainsTime() bool {
 	}
 	}
 	return false
 	return false
 }
 }
+
+func ConvertColumn(db, table string, in []*model.Column) (*Table, error) {
+	var reply Table
+	reply.Name = stringx.From(table)
+	keyMap := make(map[string][]*model.Column)
+
+	for _, column := range in {
+		keyMap[column.Key] = append(keyMap[column.Key], column)
+	}
+	primaryColumns := keyMap["PRI"]
+	if len(primaryColumns) == 0 {
+		return nil, fmt.Errorf("database:%s, table %s: missing primary key", db, table)
+	}
+
+	if len(primaryColumns) > 1 {
+		return nil, fmt.Errorf("database:%s, table %s: only one primary key expected", db, table)
+	}
+
+	primaryColumn := primaryColumns[0]
+	primaryFt, err := converter.ConvertDataType(primaryColumn.DataType)
+	if err != nil {
+		return nil, err
+	}
+
+	primaryField := Field{
+		Name:         stringx.From(primaryColumn.Name),
+		DataBaseType: primaryColumn.DataType,
+		DataType:     primaryFt,
+		IsUniqueKey:  true,
+		IsPrimaryKey: true,
+		Comment:      primaryColumn.Comment,
+	}
+	reply.PrimaryKey = Primary{
+		Field:         primaryField,
+		AutoIncrement: strings.Contains(primaryColumn.Extra, "auto_increment"),
+	}
+	for key, columns := range keyMap {
+		for _, item := range columns {
+			dt, err := converter.ConvertDataType(item.DataType)
+			if err != nil {
+				return nil, err
+			}
+
+			f := Field{
+				Name:         stringx.From(item.Name),
+				DataBaseType: item.DataType,
+				DataType:     dt,
+				IsPrimaryKey: primaryColumn.Name == item.Name,
+				Comment:      item.Comment,
+			}
+			if key == "UNI" {
+				f.IsUniqueKey = true
+			}
+			reply.Fields = append(reply.Fields, f)
+		}
+	}
+
+	return &reply, nil
+}

+ 56 - 0
tools/goctl/model/sql/parser/parser_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 	"testing"
 
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
 )
 )
 
 
 func TestParsePlainText(t *testing.T) {
 func TestParsePlainText(t *testing.T) {
@@ -23,3 +24,58 @@ func TestParseCreateTable(t *testing.T) {
 	assert.Equal(t, "id", table.PrimaryKey.Name.Source())
 	assert.Equal(t, "id", table.PrimaryKey.Name.Source())
 	assert.Equal(t, true, table.ContainsTime())
 	assert.Equal(t, true, table.ContainsTime())
 }
 }
+
+func TestConvertColumn(t *testing.T) {
+	_, err := ConvertColumn("user", "user", []*model.Column{
+		{
+			Name:     "id",
+			DataType: "bigint",
+			Key:      "",
+			Extra:    "",
+			Comment:  "",
+		},
+	})
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "missing primary key")
+
+	_, err = ConvertColumn("user", "user", []*model.Column{
+		{
+			Name:     "id",
+			DataType: "bigint",
+			Key:      "PRI",
+			Extra:    "",
+			Comment:  "",
+		},
+		{
+			Name:     "mobile",
+			DataType: "varchar",
+			Key:      "PRI",
+			Extra:    "",
+			Comment:  "手机号",
+		},
+	})
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "only one primary key expected")
+
+	table, err := ConvertColumn("user", "user", []*model.Column{
+		{
+			Name:     "id",
+			DataType: "bigint",
+			Key:      "PRI",
+			Extra:    "auto_increment",
+			Comment:  "",
+		},
+		{
+			Name:     "mobile",
+			DataType: "varchar",
+			Key:      "UNI",
+			Extra:    "",
+			Comment:  "手机号",
+		},
+	})
+	assert.Nil(t, err)
+	assert.True(t, table.PrimaryKey.AutoIncrement && table.PrimaryKey.IsPrimaryKey)
+	assert.Equal(t, "id", table.PrimaryKey.Name.Source())
+	assert.Equal(t, "mobile", table.Fields[1].Name.Source())
+	assert.True(t, table.Fields[1].IsUniqueKey)
+}

+ 15 - 9
tools/goctl/rpc/cli/cli.go

@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"fmt"
 	"path/filepath"
 	"path/filepath"
 
 
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
 	"github.com/urfave/cli"
 	"github.com/urfave/cli"
 )
 )
@@ -16,6 +15,7 @@ import (
 func Rpc(c *cli.Context) error {
 func Rpc(c *cli.Context) error {
 	src := c.String("src")
 	src := c.String("src")
 	out := c.String("dir")
 	out := c.String("dir")
+	style := c.String("style")
 	protoImportPath := c.StringSlice("proto_path")
 	protoImportPath := c.StringSlice("proto_path")
 	if len(src) == 0 {
 	if len(src) == 0 {
 		return errors.New("missing -src")
 		return errors.New("missing -src")
@@ -23,7 +23,13 @@ func Rpc(c *cli.Context) error {
 	if len(out) == 0 {
 	if len(out) == 0 {
 		return errors.New("missing -dir")
 		return errors.New("missing -dir")
 	}
 	}
-	g := generator.NewDefaultRpcGenerator()
+
+	namingStyle, valid := generator.IsNamingValid(style)
+	if !valid {
+		return fmt.Errorf("unexpected naming style %s", style)
+	}
+
+	g := generator.NewDefaultRpcGenerator(namingStyle)
 	return g.Generate(src, out, protoImportPath)
 	return g.Generate(src, out, protoImportPath)
 }
 }
 
 
@@ -36,6 +42,12 @@ func RpcNew(c *cli.Context) error {
 		return fmt.Errorf("unexpected ext: %s", ext)
 		return fmt.Errorf("unexpected ext: %s", ext)
 	}
 	}
 
 
+	style := c.String("style")
+	namingStyle, valid := generator.IsNamingValid(style)
+	if !valid {
+		return fmt.Errorf("expected naming style [lower|camel|snake], but found %s", style)
+	}
+
 	protoName := name + ".proto"
 	protoName := name + ".proto"
 	filename := filepath.Join(".", name, protoName)
 	filename := filepath.Join(".", name, protoName)
 	src, err := filepath.Abs(filename)
 	src, err := filepath.Abs(filename)
@@ -48,13 +60,7 @@ func RpcNew(c *cli.Context) error {
 		return err
 		return err
 	}
 	}
 
 
-	workDir := filepath.Dir(src)
-	_, err = execx.Run("go mod init "+name, workDir)
-	if err != nil {
-		return err
-	}
-
-	g := generator.NewDefaultRpcGenerator()
+	g := generator.NewDefaultRpcGenerator(namingStyle)
 	return g.Generate(src, filepath.Dir(src), nil)
 	return g.Generate(src, filepath.Dir(src), nil)
 }
 }
 
 

+ 0 - 75
tools/goctl/rpc/generator/base/common.pb.go

@@ -1,75 +0,0 @@
-// Code generated by protoc-gen-go. DO NOT EDIT.
-// source: common.proto
-
-package common
-
-import (
-	fmt "fmt"
-	proto "github.com/golang/protobuf/proto"
-	math "math"
-)
-
-// Reference imports to suppress errors if they are not otherwise used.
-var _ = proto.Marshal
-var _ = fmt.Errorf
-var _ = math.Inf
-
-// This is a compile-time assertion to ensure that this generated file
-// is compatible with the proto package it is being compiled against.
-// A compilation error at this line likely means your copy of the
-// proto package needs to be updated.
-const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
-
-type User struct {
-	Name                 string   `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
-	XXX_NoUnkeyedLiteral struct{} `json:"-"`
-	XXX_unrecognized     []byte   `json:"-"`
-	XXX_sizecache        int32    `json:"-"`
-}
-
-func (m *User) Reset()         { *m = User{} }
-func (m *User) String() string { return proto.CompactTextString(m) }
-func (*User) ProtoMessage()    {}
-func (*User) Descriptor() ([]byte, []int) {
-	return fileDescriptor_555bd8c177793206, []int{0}
-}
-
-func (m *User) XXX_Unmarshal(b []byte) error {
-	return xxx_messageInfo_User.Unmarshal(m, b)
-}
-func (m *User) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
-	return xxx_messageInfo_User.Marshal(b, m, deterministic)
-}
-func (m *User) XXX_Merge(src proto.Message) {
-	xxx_messageInfo_User.Merge(m, src)
-}
-func (m *User) XXX_Size() int {
-	return xxx_messageInfo_User.Size(m)
-}
-func (m *User) XXX_DiscardUnknown() {
-	xxx_messageInfo_User.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_User proto.InternalMessageInfo
-
-func (m *User) GetName() string {
-	if m != nil {
-		return m.Name
-	}
-	return ""
-}
-
-func init() {
-	proto.RegisterType((*User)(nil), "common.User")
-}
-
-func init() { proto.RegisterFile("common.proto", fileDescriptor_555bd8c177793206) }
-
-var fileDescriptor_555bd8c177793206 = []byte{
-	// 72 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0xce, 0xcf, 0xcd,
-	0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0xa4, 0xb8, 0x58,
-	0x42, 0x8b, 0x53, 0x8b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, 0x25, 0x18, 0x15, 0x18,
-	0x35, 0x38, 0x83, 0xc0, 0xec, 0x24, 0x36, 0xb0, 0x52, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
-	0x2c, 0x6d, 0x58, 0x59, 0x3a, 0x00, 0x00, 0x00,
-}

+ 9 - 2
tools/goctl/rpc/generator/filename.go

@@ -6,6 +6,13 @@ import (
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 )
 
 
-func formatFilename(filename string) string {
-	return strings.ToLower(stringx.From(filename).ToCamel())
+func formatFilename(filename string, style NamingStyle) string {
+	switch style {
+	case namingCamel:
+		return stringx.From(filename).ToCamel()
+	case namingSnake:
+		return stringx.From(filename).ToSnake()
+	default:
+		return strings.ToLower(stringx.From(filename).ToCamel())
+	}
 }
 }

+ 17 - 0
tools/goctl/rpc/generator/filename_test.go

@@ -0,0 +1,17 @@
+package generator
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestFormatFilename(t *testing.T) {
+	assert.Equal(t, "abc", formatFilename("a_b_c", namingLower))
+	assert.Equal(t, "ABC", formatFilename("a_b_c", namingCamel))
+	assert.Equal(t, "a_b_c", formatFilename("a_b_c", namingSnake))
+	assert.Equal(t, "a", formatFilename("a", namingSnake))
+	assert.Equal(t, "A", formatFilename("a", namingCamel))
+	// no flag to convert to snake
+	assert.Equal(t, "abc", formatFilename("abc", namingSnake))
+}

+ 15 - 13
tools/goctl/rpc/generator/gen.go

@@ -10,16 +10,18 @@ import (
 )
 )
 
 
 type RpcGenerator struct {
 type RpcGenerator struct {
-	g Generator
+	g     Generator
+	style NamingStyle
 }
 }
 
 
-func NewDefaultRpcGenerator() *RpcGenerator {
-	return NewRpcGenerator(NewDefaultGenerator())
+func NewDefaultRpcGenerator(style NamingStyle) *RpcGenerator {
+	return NewRpcGenerator(NewDefaultGenerator(), style)
 }
 }
 
 
-func NewRpcGenerator(g Generator) *RpcGenerator {
+func NewRpcGenerator(g Generator, style NamingStyle) *RpcGenerator {
 	return &RpcGenerator{
 	return &RpcGenerator{
-		g: g,
+		g:     g,
+		style: style,
 	}
 	}
 }
 }
 
 
@@ -55,42 +57,42 @@ func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) er
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenEtc(dirCtx, proto)
+	err = g.g.GenEtc(dirCtx, proto, g.style)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenPb(dirCtx, protoImportPath, proto)
+	err = g.g.GenPb(dirCtx, protoImportPath, proto, g.style)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenConfig(dirCtx, proto)
+	err = g.g.GenConfig(dirCtx, proto, g.style)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenSvc(dirCtx, proto)
+	err = g.g.GenSvc(dirCtx, proto, g.style)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenLogic(dirCtx, proto)
+	err = g.g.GenLogic(dirCtx, proto, g.style)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenServer(dirCtx, proto)
+	err = g.g.GenServer(dirCtx, proto, g.style)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenMain(dirCtx, proto)
+	err = g.g.GenMain(dirCtx, proto, g.style)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = g.g.GenCall(dirCtx, proto)
+	err = g.g.GenCall(dirCtx, proto, g.style)
 
 
 	console.NewColorConsole().MarkDone()
 	console.NewColorConsole().MarkDone()
 
 

+ 50 - 104
tools/goctl/rpc/generator/gen_test.go

@@ -1,128 +1,74 @@
 package generator
 package generator
 
 
 import (
 import (
+	"go/build"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
-	"strings"
 	"testing"
 	"testing"
 
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/logx"
+	"github.com/tal-tech/go-zero/core/stringx"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
 )
 )
 
 
-func TestRpcGenerateCaseNilImport(t *testing.T) {
+func TestRpcGenerate(t *testing.T) {
 	_ = Clean()
 	_ = Clean()
 	dispatcher := NewDefaultGenerator()
 	dispatcher := NewDefaultGenerator()
-	if err := dispatcher.Prepare(); err == nil {
-		g := NewRpcGenerator(dispatcher)
-		abs, err := filepath.Abs("./test")
-		assert.Nil(t, err)
-
-		err = g.Generate("./test_stream.proto", abs, nil)
-		defer func() {
-			_ = os.RemoveAll(abs)
-		}()
-		assert.Nil(t, err)
-
-		_, err = execx.Run("go test "+abs, abs)
-		assert.Nil(t, err)
+	err := dispatcher.Prepare()
+	if err != nil {
+		logx.Error(err)
+		return
 	}
 	}
-}
-
-func TestRpcGenerateCaseOption(t *testing.T) {
-	_ = Clean()
-	dispatcher := NewDefaultGenerator()
-	if err := dispatcher.Prepare(); err == nil {
-		g := NewRpcGenerator(dispatcher)
-		abs, err := filepath.Abs("./test")
-		assert.Nil(t, err)
-
-		err = g.Generate("./test_option.proto", abs, nil)
-		defer func() {
-			_ = os.RemoveAll(abs)
-		}()
-		assert.Nil(t, err)
-
-		_, err = execx.Run("go test "+abs, abs)
-		assert.Nil(t, err)
+	projectName := stringx.Rand()
+	g := NewRpcGenerator(dispatcher, namingLower)
+
+	// case go path
+	src := filepath.Join(build.Default.GOPATH, "src")
+	_, err = os.Stat(src)
+	if err != nil {
+		return
 	}
 	}
-}
-
-func TestRpcGenerateCaseWordOption(t *testing.T) {
-	_ = Clean()
-	dispatcher := NewDefaultGenerator()
-	if err := dispatcher.Prepare(); err == nil {
-		g := NewRpcGenerator(dispatcher)
-		abs, err := filepath.Abs("./test")
-		assert.Nil(t, err)
 
 
-		err = g.Generate("./test_word_option.proto", abs, nil)
-		defer func() {
-			_ = os.RemoveAll(abs)
-		}()
-		assert.Nil(t, err)
-
-		_, err = execx.Run("go test "+abs, abs)
-		assert.Nil(t, err)
+	projectDir := filepath.Join(src, projectName)
+	srcDir := projectDir
+	defer func() {
+		_ = os.RemoveAll(srcDir)
+	}()
+	err = g.Generate("./test.proto", projectDir, []string{src})
+	assert.Nil(t, err)
+	_, err = execx.Run("go test "+projectName, projectDir)
+	if err != nil {
+		assert.Contains(t, err.Error(), "not in GOROOT")
 	}
 	}
-}
 
 
-// test keyword go
-func TestRpcGenerateCaseGoOption(t *testing.T) {
-	_ = Clean()
-	dispatcher := NewDefaultGenerator()
-	if err := dispatcher.Prepare(); err == nil {
-		g := NewRpcGenerator(dispatcher)
-		abs, err := filepath.Abs("./test")
-		assert.Nil(t, err)
-
-		err = g.Generate("./test_go_option.proto", abs, nil)
-		defer func() {
-			_ = os.RemoveAll(abs)
-		}()
-		assert.Nil(t, err)
-
-		_, err = execx.Run("go test "+abs, abs)
-		assert.Nil(t, err)
+	// case go mod
+	workDir := t.TempDir()
+	name := filepath.Base(workDir)
+	_, err = execx.Run("go mod init "+name, workDir)
+	if err != nil {
+		logx.Error(err)
+		return
 	}
 	}
-}
-
-func TestRpcGenerateCaseImport(t *testing.T) {
-	_ = Clean()
-	dispatcher := NewDefaultGenerator()
-	if err := dispatcher.Prepare(); err == nil {
-		g := NewRpcGenerator(dispatcher)
-		abs, err := filepath.Abs("./test")
-		assert.Nil(t, err)
 
 
-		err = g.Generate("./test_import.proto", abs, []string{"./base"})
-		defer func() {
-			_ = os.RemoveAll(abs)
-		}()
-		assert.Nil(t, err)
-
-		_, err = execx.Run("go test "+abs, abs)
-		assert.True(t, func() bool {
-			return strings.Contains(err.Error(), "package base is not in GOROOT")
-		}())
+	projectDir = filepath.Join(workDir, projectName)
+	err = g.Generate("./test.proto", projectDir, []string{src})
+	assert.Nil(t, err)
+	_, err = execx.Run("go test "+projectName, projectDir)
+	if err != nil {
+		assert.Contains(t, err.Error(), "not in GOROOT")
 	}
 	}
-}
 
 
-func TestRpcGenerateCaseServiceRpcNamingSnake(t *testing.T) {
-	_ = Clean()
-	dispatcher := NewDefaultGenerator()
-	if err := dispatcher.Prepare(); err == nil {
-		g := NewRpcGenerator(dispatcher)
-		abs, err := filepath.Abs("./test")
-		assert.Nil(t, err)
-
-		err = g.Generate("./test_service_rpc_naming_snake.proto", abs, nil)
-		defer func() {
-			_ = os.RemoveAll(abs)
-		}()
-		assert.Nil(t, err)
-
-		_, err = execx.Run("go test "+abs, abs)
-		assert.Nil(t, err)
+	// case not in go mod and go path
+	err = g.Generate("./test.proto", projectDir, []string{src})
+	assert.Nil(t, err)
+	_, err = execx.Run("go test "+projectName, projectDir)
+	if err != nil {
+		assert.Contains(t, err.Error(), "not in GOROOT")
 	}
 	}
+
+	// invalid directory
+	projectDir = filepath.Join(t.TempDir(), ".....")
+	err = g.Generate("./test.proto", projectDir, nil)
+	assert.NotNil(t, err)
 }
 }

+ 5 - 6
tools/goctl/rpc/generator/gencall.go

@@ -59,12 +59,12 @@ func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbReque
 `
 `
 )
 )
 
 
-func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
+func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetCall()
 	dir := ctx.GetCall()
 	service := proto.Service
 	service := proto.Service
 	head := util.GetHead(proto.Name)
 	head := util.GetHead(proto.Name)
 
 
-	filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name)))
+	filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name, namingStyle)))
 	functions, err := g.genFunction(proto.PbPackage, service)
 	functions, err := g.genFunction(proto.PbPackage, service)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -81,13 +81,12 @@ func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
 	}
 	}
 
 
 	var alias = collection.NewSet()
 	var alias = collection.NewSet()
-	for _, item := range service.RPC {
-		alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.RequestType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.RequestType))))
-		alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.ReturnsType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.ReturnsType))))
+	for _, item := range proto.Message {
+		alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.Name), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.Name))))
 	}
 	}
 
 
 	err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
 	err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
-		"name":        formatFilename(service.Name),
+		"name":        formatFilename(service.Name, namingStyle),
 		"alias":       strings.Join(alias.KeysStr(), util.NL),
 		"alias":       strings.Join(alias.KeysStr(), util.NL),
 		"head":        head,
 		"head":        head,
 		"filePackage": dir.Base,
 		"filePackage": dir.Base,

+ 0 - 44
tools/goctl/rpc/generator/gencall_test.go

@@ -1,44 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateCall(t *testing.T) {
-	_ = Clean()
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	err = g.Prepare()
-	if err != nil {
-		return
-	}
-	err = g.GenCall(dirCtx, proto)
-	assert.Nil(t, err)
-}

+ 2 - 2
tools/goctl/rpc/generator/genconfig.go

@@ -18,9 +18,9 @@ type Config struct {
 }
 }
 `
 `
 
 
-func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto) error {
+func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetConfig()
 	dir := ctx.GetConfig()
-	fileName := filepath.Join(dir.Filename, formatFilename("config")+".go")
+	fileName := filepath.Join(dir.Filename, formatFilename("config", namingStyle)+".go")
 	if util.FileExists(fileName) {
 	if util.FileExists(fileName) {
 		return nil
 		return nil
 	}
 	}

+ 0 - 48
tools/goctl/rpc/generator/genconfig_test.go

@@ -1,48 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateConfig(t *testing.T) {
-	_ = Clean()
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	err = g.Prepare()
-	if err != nil {
-		return
-	}
-	err = g.GenConfig(dirCtx, proto)
-	assert.Nil(t, err)
-
-	// test file exists
-	err = g.GenConfig(dirCtx, proto)
-	assert.Nil(t, err)
-}

+ 8 - 8
tools/goctl/rpc/generator/generator.go

@@ -4,12 +4,12 @@ import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
 
 
 type Generator interface {
 type Generator interface {
 	Prepare() error
 	Prepare() error
-	GenMain(ctx DirContext, proto parser.Proto) error
-	GenCall(ctx DirContext, proto parser.Proto) error
-	GenEtc(ctx DirContext, proto parser.Proto) error
-	GenConfig(ctx DirContext, proto parser.Proto) error
-	GenLogic(ctx DirContext, proto parser.Proto) error
-	GenServer(ctx DirContext, proto parser.Proto) error
-	GenSvc(ctx DirContext, proto parser.Proto) error
-	GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error
+	GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
+	GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
+	GenEtc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
+	GenConfig(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
+	GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
+	GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
+	GenSvc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
+	GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error
 }
 }

+ 5 - 3
tools/goctl/rpc/generator/genetc.go

@@ -3,9 +3,11 @@ package generator
 import (
 import (
 	"fmt"
 	"fmt"
 	"path/filepath"
 	"path/filepath"
+	"strings"
 
 
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 )
 
 
 const etcTemplate = `Name: {{.serviceName}}.rpc
 const etcTemplate = `Name: {{.serviceName}}.rpc
@@ -16,9 +18,9 @@ Etcd:
   Key: {{.serviceName}}.rpc
   Key: {{.serviceName}}.rpc
 `
 `
 
 
-func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error {
+func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetEtc()
 	dir := ctx.GetEtc()
-	serviceNameLower := formatFilename(ctx.GetMain().Base)
+	serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle)
 	fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower))
 	fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower))
 
 
 	text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
 	text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
@@ -27,6 +29,6 @@ func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error {
 	}
 	}
 
 
 	return util.With("etc").Parse(text).SaveTo(map[string]interface{}{
 	return util.With("etc").Parse(text).SaveTo(map[string]interface{}{
-		"serviceName": serviceNameLower,
+		"serviceName": strings.ToLower(stringx.From(ctx.GetMain().Base).ToCamel()),
 	}, fileName, false)
 	}, fileName, false)
 }
 }

+ 0 - 45
tools/goctl/rpc/generator/genetc_test.go

@@ -1,45 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateEtc(t *testing.T) {
-	_ = Clean()
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	err = g.Prepare()
-	if err != nil {
-		return
-	}
-
-	err = g.GenEtc(dirCtx, proto)
-	assert.Nil(t, err)
-}

+ 2 - 2
tools/goctl/rpc/generator/genlogic.go

@@ -46,10 +46,10 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
 `
 `
 )
 )
 
 
-func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error {
+func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetLogic()
 	dir := ctx.GetLogic()
 	for _, rpc := range proto.Service.RPC {
 	for _, rpc := range proto.Service.RPC {
-		filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic")+".go")
+		filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic", namingStyle)+".go")
 		functions, err := g.genLogicFunction(proto.PbPackage, rpc)
 		functions, err := g.genLogicFunction(proto.PbPackage, rpc)
 		if err != nil {
 		if err != nil {
 			return err
 			return err

+ 0 - 44
tools/goctl/rpc/generator/genlogic_test.go

@@ -1,44 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateLogic(t *testing.T) {
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	err = g.Prepare()
-	if err != nil {
-		return
-	}
-
-	err = g.GenLogic(dirCtx, proto)
-	assert.Nil(t, err)
-}

+ 3 - 3
tools/goctl/rpc/generator/genmain.go

@@ -45,9 +45,9 @@ func main() {
 }
 }
 `
 `
 
 
-func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
+func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetMain()
 	dir := ctx.GetMain()
-	serviceNameLower := formatFilename(ctx.GetMain().Base)
+	serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle)
 	fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower))
 	fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower))
 	imports := make([]string, 0)
 	imports := make([]string, 0)
 	pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
 	pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
@@ -63,7 +63,7 @@ func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
 
 
 	return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
 	return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
 		"head":        head,
 		"head":        head,
-		"serviceName": serviceNameLower,
+		"serviceName": strings.ToLower(stringx.From(ctx.GetMain().Base).ToCamel()),
 		"imports":     strings.Join(imports, util.NL),
 		"imports":     strings.Join(imports, util.NL),
 		"pkg":         proto.PbPackage,
 		"pkg":         proto.PbPackage,
 		"serviceNew":  stringx.From(proto.Service.Name).ToCamel(),
 		"serviceNew":  stringx.From(proto.Service.Name).ToCamel(),

+ 0 - 45
tools/goctl/rpc/generator/genmain_test.go

@@ -1,45 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateMain(t *testing.T) {
-	_ = Clean()
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	err = g.Prepare()
-	if err != nil {
-		return
-	}
-
-	err = g.GenMain(dirCtx, proto)
-	assert.Nil(t, err)
-}

+ 1 - 1
tools/goctl/rpc/generator/genpb.go

@@ -9,7 +9,7 @@ import (
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
 )
 )
 
 
-func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error {
+func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetPb()
 	dir := ctx.GetPb()
 	cw := new(bytes.Buffer)
 	cw := new(bytes.Buffer)
 	base := filepath.Dir(proto.Src)
 	base := filepath.Dir(proto.Src)

+ 0 - 184
tools/goctl/rpc/generator/genpb_test.go

@@ -1,184 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateCaseNilImport(t *testing.T) {
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		//_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	if err := g.Prepare(); err == nil {
-		targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
-		err = g.GenPb(dirCtx, nil, proto)
-		assert.Nil(t, err)
-		assert.True(t, func() bool {
-			return util.FileExists(targetPb)
-		}())
-	}
-}
-
-func TestGenerateCaseImport(t *testing.T) {
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	if err := g.Prepare(); err == nil {
-		err = g.GenPb(dirCtx, nil, proto)
-		assert.Nil(t, err)
-
-		targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
-		assert.True(t, func() bool {
-			return util.FileExists(targetPb)
-		}())
-	}
-}
-
-func TestGenerateCasePathOption(t *testing.T) {
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_option.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	if err := g.Prepare(); err == nil {
-		err = g.GenPb(dirCtx, nil, proto)
-		assert.Nil(t, err)
-
-		targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_option.pb.go")
-		assert.True(t, func() bool {
-			return util.FileExists(targetPb)
-		}())
-	}
-}
-
-func TestGenerateCaseWordOption(t *testing.T) {
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_word_option.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	if err := g.Prepare(); err == nil {
-
-		err = g.GenPb(dirCtx, nil, proto)
-		assert.Nil(t, err)
-
-		targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_word_option.pb.go")
-		assert.True(t, func() bool {
-			return util.FileExists(targetPb)
-		}())
-	}
-}
-
-// test keyword go
-func TestGenerateCaseGoOption(t *testing.T) {
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_go_option.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	if err := g.Prepare(); err == nil {
-
-		err = g.GenPb(dirCtx, nil, proto)
-		assert.Nil(t, err)
-
-		targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_go_option.pb.go")
-		assert.True(t, func() bool {
-			return util.FileExists(targetPb)
-		}())
-	}
-}

+ 2 - 2
tools/goctl/rpc/generator/genserver.go

@@ -43,7 +43,7 @@ func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) (
 `
 `
 )
 )
 
 
-func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
+func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetServer()
 	dir := ctx.GetServer()
 	logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
 	logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
 	svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
 	svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
@@ -54,7 +54,7 @@ func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
 
 
 	head := util.GetHead(proto.Name)
 	head := util.GetHead(proto.Name)
 	service := proto.Service
 	service := proto.Service
-	serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server")+".go")
+	serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server", namingStyle)+".go")
 	funcList, err := g.genFunctions(proto.PbPackage, service)
 	funcList, err := g.genFunctions(proto.PbPackage, service)
 	if err != nil {
 	if err != nil {
 		return err
 		return err

+ 0 - 45
tools/goctl/rpc/generator/genserver_test.go

@@ -1,45 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateServer(t *testing.T) {
-	_ = Clean()
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	err = g.Prepare()
-	if err != nil {
-		return
-	}
-
-	err = g.GenServer(dirCtx, proto)
-	assert.Nil(t, err)
-}

+ 2 - 2
tools/goctl/rpc/generator/gensvc.go

@@ -23,9 +23,9 @@ func NewServiceContext(c config.Config) *ServiceContext {
 }
 }
 `
 `
 
 
-func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto) error {
+func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error {
 	dir := ctx.GetSvc()
 	dir := ctx.GetSvc()
-	fileName := filepath.Join(dir.Filename, formatFilename("service_context")+".go")
+	fileName := filepath.Join(dir.Filename, formatFilename("service_context", namingStyle)+".go")
 	text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate)
 	text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate)
 	if err != nil {
 	if err != nil {
 		return err
 		return err

+ 0 - 40
tools/goctl/rpc/generator/gensvc_test.go

@@ -1,40 +0,0 @@
-package generator
-
-import (
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestGenerateSvc(t *testing.T) {
-	_ = Clean()
-	project := "stream"
-	abs, err := filepath.Abs("./test")
-	assert.Nil(t, err)
-
-	dir := filepath.Join(abs, project)
-	err = util.MkdirIfNotExist(dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(abs)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test_stream.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-
-	g := NewDefaultGenerator()
-	err = g.GenSvc(dirCtx, proto)
-	assert.Nil(t, err)
-}

+ 0 - 130
tools/goctl/rpc/generator/mkdir_test.go

@@ -1,130 +0,0 @@
-package generator
-
-import (
-	"go/build"
-	"os"
-	"path/filepath"
-	"strings"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/tal-tech/go-zero/core/stringx"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
-	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
-	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
-)
-
-func TestMkDirInGoPath(t *testing.T) {
-	dft := build.Default
-	gp := dft.GOPATH
-	if len(gp) == 0 {
-		return
-	}
-	projectName := stringx.Rand()
-	dir := filepath.Join(gp, "src", projectName)
-	err := util.MkdirIfNotExist(dir)
-	if err != nil {
-		return
-	}
-
-	defer func() {
-		_ = os.RemoveAll(dir)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-	internal := filepath.Join(dir, "internal")
-	assert.True(t, true, func() bool {
-		return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
-	}())
-	assert.True(t, true, func() bool {
-		return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
-	}())
-	assert.True(t, true, func() bool {
-		return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
-	}())
-}
-
-func TestMkDirInGoMod(t *testing.T) {
-	dft := build.Default
-	gp := dft.GOPATH
-	if len(gp) == 0 {
-		return
-	}
-	projectName := stringx.Rand()
-	dir := filepath.Join(gp, "src", projectName)
-	err := util.MkdirIfNotExist(dir)
-	if err != nil {
-		return
-	}
-
-	_, err = execx.Run("go mod init "+projectName, dir)
-	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(dir)
-	}()
-
-	projectCtx, err := ctx.Prepare(dir)
-	assert.Nil(t, err)
-
-	p := parser.NewDefaultProtoParser()
-	proto, err := p.Parse("./test.proto")
-	assert.Nil(t, err)
-
-	dirCtx, err := mkdir(projectCtx, proto)
-	assert.Nil(t, err)
-	internal := filepath.Join(dir, "internal")
-	assert.True(t, true, func() bool {
-		return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
-	}())
-	assert.True(t, true, func() bool {
-		return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
-	}())
-	assert.True(t, true, func() bool {
-		return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
-	}())
-	assert.True(t, true, func() bool {
-		return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
-	}())
-}

+ 24 - 0
tools/goctl/rpc/generator/naming.go

@@ -0,0 +1,24 @@
+package generator
+
+type NamingStyle = string
+
+const (
+	namingLower NamingStyle = "lower"
+	namingCamel NamingStyle = "camel"
+	namingSnake NamingStyle = "snake"
+)
+
+// IsNamingValid validates whether the namingStyle is valid or not,return
+// namingStyle and true if it is valid, or else return empty string
+// and false, and it is a valid value even namingStyle is empty string
+func IsNamingValid(namingStyle string) (NamingStyle, bool) {
+	if len(namingStyle) == 0 {
+		namingStyle = namingLower
+	}
+	switch namingStyle {
+	case namingLower, namingCamel, namingSnake:
+		return namingStyle, true
+	default:
+		return "", false
+	}
+}

+ 25 - 0
tools/goctl/rpc/generator/naming_test.go

@@ -0,0 +1,25 @@
+package generator
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestIsNamingValid(t *testing.T) {
+	style, valid := IsNamingValid("")
+	assert.True(t, valid)
+	assert.Equal(t, namingLower, style)
+
+	_, valid = IsNamingValid("lower1")
+	assert.False(t, valid)
+
+	_, valid = IsNamingValid("lower")
+	assert.True(t, valid)
+
+	_, valid = IsNamingValid("snake")
+	assert.True(t, valid)
+
+	_, valid = IsNamingValid("camel")
+	assert.True(t, valid)
+}

+ 7 - 8
tools/goctl/rpc/generator/prototmpl_test.go

@@ -1,7 +1,6 @@
 package generator
 package generator
 
 
 import (
 import (
-	"os"
 	"path/filepath"
 	"path/filepath"
 	"testing"
 	"testing"
 
 
@@ -9,13 +8,13 @@ import (
 )
 )
 
 
 func TestProtoTmpl(t *testing.T) {
 func TestProtoTmpl(t *testing.T) {
-	out, err := filepath.Abs("./test/test.proto")
+	_ = Clean()
+	// exists dir
+	err := ProtoTmpl(t.TempDir())
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	defer func() {
-		_ = os.RemoveAll(filepath.Dir(out))
-	}()
-	err = ProtoTmpl(out)
-	assert.Nil(t, err)
-	_, err = os.Stat(out)
+
+	// not exist dir
+	dir := filepath.Join(t.TempDir(), "test")
+	err = ProtoTmpl(dir)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 }
 }

+ 83 - 65
tools/goctl/rpc/generator/template_test.go

@@ -2,6 +2,7 @@ package generator
 
 
 import (
 import (
 	"io/ioutil"
 	"io/ioutil"
+	"os"
 	"path/filepath"
 	"path/filepath"
 	"testing"
 	"testing"
 
 
@@ -10,87 +11,104 @@ import (
 )
 )
 
 
 func TestGenTemplates(t *testing.T) {
 func TestGenTemplates(t *testing.T) {
-	err := util.InitTemplates(category, templates)
+	_ = Clean()
+	err := GenTemplates(nil)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	dir, err := util.GetTemplateDir(category)
-	assert.Nil(t, err)
-	file := filepath.Join(dir, "main.tpl")
-	data, err := ioutil.ReadFile(file)
-	assert.Nil(t, err)
-	assert.Equal(t, string(data), mainTemplate)
 }
 }
 
 
 func TestRevertTemplate(t *testing.T) {
 func TestRevertTemplate(t *testing.T) {
-	name := "main.tpl"
-	err := util.InitTemplates(category, templates)
-	assert.Nil(t, err)
-
-	dir, err := util.GetTemplateDir(category)
-	assert.Nil(t, err)
-
-	file := filepath.Join(dir, name)
-	data, err := ioutil.ReadFile(file)
-	assert.Nil(t, err)
-
-	modifyData := string(data) + "modify"
-	err = util.CreateTemplate(category, name, modifyData)
-	assert.Nil(t, err)
-
-	data, err = ioutil.ReadFile(file)
-	assert.Nil(t, err)
-
-	assert.Equal(t, string(data), modifyData)
-
-	assert.Nil(t, RevertTemplate(name))
-
-	data, err = ioutil.ReadFile(file)
-	assert.Nil(t, err)
-	assert.Equal(t, mainTemplate, string(data))
+	_ = Clean()
+	err := GenTemplates(nil)
+	assert.Nil(t, err)
+	fp, err := util.GetTemplateDir(category)
+	if err != nil {
+		return
+	}
+	mainTpl := filepath.Join(fp, mainTemplateFile)
+	data, err := ioutil.ReadFile(mainTpl)
+	if err != nil {
+		return
+	}
+	assert.Equal(t, templates[mainTemplateFile], string(data))
+
+	err = RevertTemplate("test")
+	if err != nil {
+		assert.Equal(t, "test: no such file name", err.Error())
+	}
+
+	err = ioutil.WriteFile(mainTpl, []byte("modify"), os.ModePerm)
+	if err != nil {
+		return
+	}
+
+	data, err = ioutil.ReadFile(mainTpl)
+	if err != nil {
+		return
+	}
+	assert.Equal(t, "modify", string(data))
+
+	err = RevertTemplate(mainTemplateFile)
+	assert.Nil(t, err)
+
+	data, err = ioutil.ReadFile(mainTpl)
+	if err != nil {
+		return
+	}
+	assert.Equal(t, templates[mainTemplateFile], string(data))
 }
 }
 
 
 func TestClean(t *testing.T) {
 func TestClean(t *testing.T) {
-	name := "main.tpl"
-	err := util.InitTemplates(category, templates)
+	_ = Clean()
+	err := GenTemplates(nil)
+	assert.Nil(t, err)
+	fp, err := util.GetTemplateDir(category)
+	if err != nil {
+		return
+	}
+	mainTpl := filepath.Join(fp, mainTemplateFile)
+	_, err = os.Stat(mainTpl)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
-	assert.Nil(t, Clean())
-
-	dir, err := util.GetTemplateDir(category)
+	err = Clean()
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
-	file := filepath.Join(dir, name)
-	_, err = ioutil.ReadFile(file)
+	_, err = os.Stat(mainTpl)
 	assert.NotNil(t, err)
 	assert.NotNil(t, err)
 }
 }
 
 
 func TestUpdate(t *testing.T) {
 func TestUpdate(t *testing.T) {
-	name := "main.tpl"
-	err := util.InitTemplates(category, templates)
-	assert.Nil(t, err)
-
-	dir, err := util.GetTemplateDir(category)
-	assert.Nil(t, err)
-
-	file := filepath.Join(dir, name)
-	data, err := ioutil.ReadFile(file)
-	assert.Nil(t, err)
-
-	modifyData := string(data) + "modify"
-	err = util.CreateTemplate(category, name, modifyData)
-	assert.Nil(t, err)
-
-	data, err = ioutil.ReadFile(file)
-	assert.Nil(t, err)
-
-	assert.Equal(t, string(data), modifyData)
-
-	assert.Nil(t, Update(category))
-
-	data, err = ioutil.ReadFile(file)
-	assert.Nil(t, err)
-	assert.Equal(t, mainTemplate, string(data))
+	_ = Clean()
+	err := GenTemplates(nil)
+	assert.Nil(t, err)
+	fp, err := util.GetTemplateDir(category)
+	if err != nil {
+		return
+	}
+	mainTpl := filepath.Join(fp, mainTemplateFile)
+
+	err = ioutil.WriteFile(mainTpl, []byte("modify"), os.ModePerm)
+	if err != nil {
+		return
+	}
+
+	data, err := ioutil.ReadFile(mainTpl)
+	if err != nil {
+		return
+	}
+	assert.Equal(t, "modify", string(data))
+
+	err = Update(category)
+	assert.Nil(t, err)
+
+	data, err = ioutil.ReadFile(mainTpl)
+	if err != nil {
+		return
+	}
+	assert.Equal(t, templates[mainTemplateFile], string(data))
 }
 }
 
 
 func TestGetCategory(t *testing.T) {
 func TestGetCategory(t *testing.T) {
-	assert.Equal(t, category, GetCategory())
+	_ = Clean()
+	result := GetCategory()
+	assert.Equal(t, category, result)
 }
 }

+ 50 - 13
tools/goctl/rpc/generator/test.proto

@@ -2,24 +2,61 @@
 syntax = "proto3";
 syntax = "proto3";
 
 
 package test;
 package test;
-option go_package = "go";
 
 
-import "test_base.proto";
+import "base/common.proto";
+import "google/protobuf/any.proto";
 
 
-message TestMessage {
-  base.CommonReq req = 1;
+option go_package = "github.com/test";
+
+message Req {
+  string in = 1;
+  common.User user = 2;
+  google.protobuf.Any object = 4;
+}
+
+message Reply {
+  string out = 1;
 }
 }
-message TestReq {}
-message TestReply {
-  base.CommonReply reply = 2;
+
+message snake_req {}
+
+message snake_reply {}
+
+message CamelReq{}
+
+message CamelReply{}
+
+message EnumMessage {
+  enum Enum {
+    unknown = 0;
+    male = 1;
+    female = 2;
+  }
+}
+
+message CommonReply{}
+
+message MapReq{
+  map<string, string> m = 1;
 }
 }
 
 
-enum TestEnum {
-  unknown = 0;
-  male = 1;
-  female = 2;
+message RepeatedReq{
+  repeated string id = 1;
 }
 }
 
 
-service TestService {
-  rpc TestRpc (TestReq) returns (TestReply);
+service Test_Service {
+  // service
+  rpc Service (Req) returns (Reply);
+  // greet service
+  rpc GreetService (Req) returns (Reply);
+  // case snake
+  rpc snake_service (snake_req) returns (snake_reply);
+  // case camel
+  rpc CamelService (CamelReq) returns (CamelReply);
+  // case enum
+  rpc EnumService (EnumMessage) returns (CommonReply);
+  // case map
+  rpc MapService (MapReq) returns (CommonReply);
+  // case repeated
+  rpc RepeatedService (RepeatedReq) returns (CommonReply);
 }
 }

+ 0 - 12
tools/goctl/rpc/generator/test_base.proto

@@ -1,12 +0,0 @@
-// test proto
-syntax = "proto3";
-
-package base;
-
-message CommonReq {
-  string in = 1;
-}
-
-message CommonReply {
-  string out = 1;
-}

+ 0 - 18
tools/goctl/rpc/generator/test_go_option.proto

@@ -1,18 +0,0 @@
-// test proto
-syntax = "proto3";
-
-package stream;
-
-option go_package="go";
-
-message StreamReq {
-  string name = 1;
-}
-
-message StreamResp {
-  string greet = 1;
-}
-
-service StreamGreeter {
-  rpc greet (StreamReq) returns (StreamResp);
-}

+ 0 - 18
tools/goctl/rpc/generator/test_import.proto

@@ -1,18 +0,0 @@
-// test proto
-syntax = "proto3";
-
-package greet;
-import "base/common.proto";
-
-message In {
-  string name = 1;
-  common.User user = 2;
-}
-
-message Out {
-  string greet = 1;
-}
-
-service StreamGreeter {
-  rpc greet (In) returns (Out);
-}

+ 0 - 18
tools/goctl/rpc/generator/test_option.proto

@@ -1,18 +0,0 @@
-// test proto
-syntax = "proto3";
-
-package stream;
-
-option go_package="github.com/tal-tech/go-zero";
-
-message StreamReq {
-  string name = 1;
-}
-
-message StreamResp {
-  string greet = 1;
-}
-
-service StreamGreeter {
-  rpc greet (StreamReq) returns (StreamResp);
-}

+ 0 - 27
tools/goctl/rpc/generator/test_service_rpc_naming_snake.proto

@@ -1,27 +0,0 @@
-// test proto
-syntax = "proto3";
-
-package snake_package;
-
-message StreamReq {
-  string name = 1;
-}
-
-message Stream_Resp {
-  string greet = 1;
-}
-
-message lowercase {
-  string in = 1;
-  string lower = 2;
-}
-
-message CamelCase {
-  string Camel = 1;
-}
-
-service Stream_Greeter {
-  rpc snake_service(StreamReq) returns (Stream_Resp);
-  rpc ServiceCamelCase(CamelCase) returns (CamelCase);
-  rpc servicelowercase(lowercase) returns (lowercase);
-}

+ 0 - 17
tools/goctl/rpc/generator/test_stream.proto

@@ -1,17 +0,0 @@
-// test proto
-syntax = "proto3";
-
-package stream;
-
-message StreamReq {
-  string name = 1;
-}
-
-message StreamResp {
-  string greet = 1;
-}
-
-service StreamGreeter {
-  // greet service
-  rpc greet (StreamReq) returns (StreamResp);
-}

+ 0 - 18
tools/goctl/rpc/generator/test_word_option.proto

@@ -1,18 +0,0 @@
-// test proto
-syntax = "proto3";
-
-package stream;
-
-option go_package="user";
-
-message StreamReq {
-  string name = 1;
-}
-
-message StreamResp {
-  string greet = 1;
-}
-
-service StreamGreeter {
-  rpc greet(StreamReq) returns (StreamResp);
-}

+ 2 - 0
tools/goctl/util/stringx/string.go

@@ -29,9 +29,11 @@ func (s String) IsEmptyOrSpace() bool {
 func (s String) Lower() string {
 func (s String) Lower() string {
 	return strings.ToLower(s.source)
 	return strings.ToLower(s.source)
 }
 }
+
 func (s String) Upper() string {
 func (s String) Upper() string {
 	return strings.ToUpper(s.source)
 	return strings.ToUpper(s.source)
 }
 }
+
 func (s String) Title() string {
 func (s String) Title() string {
 	if s.IsEmptyOrSpace() {
 	if s.IsEmptyOrSpace() {
 		return s.source
 		return s.source