1
0
Эх сурвалжийг харах

rpc generation fix (#184)

* reactor alert

* optimize

* add test case

* update the target directory in case proto contains option

* fix missing comments and format code
Keson 4 жил өмнө
parent
commit
856b5aadb1

+ 2 - 2
tools/goctl/rpc/cli/cli.go

@@ -18,10 +18,10 @@ func Rpc(c *cli.Context) error {
 	out := c.String("dir")
 	out := c.String("dir")
 	protoImportPath := c.StringSlice("proto_path")
 	protoImportPath := c.StringSlice("proto_path")
 	if len(src) == 0 {
 	if len(src) == 0 {
-		return errors.New("the proto source can not be nil")
+		return errors.New("missing -src")
 	}
 	}
 	if len(out) == 0 {
 	if len(out) == 0 {
-		return errors.New("the target directory can not be nil")
+		return errors.New("missing -dir")
 	}
 	}
 	g := generator.NewDefaultRpcGenerator()
 	g := generator.NewDefaultRpcGenerator()
 	return g.Generate(src, out, protoImportPath)
 	return g.Generate(src, out, protoImportPath)

+ 24 - 0
tools/goctl/rpc/generator/gen_test.go

@@ -11,6 +11,7 @@ import (
 )
 )
 
 
 func TestRpcGenerateCaseNilImport(t *testing.T) {
 func TestRpcGenerateCaseNilImport(t *testing.T) {
+	_ = Clean()
 	dispatcher := NewDefaultGenerator()
 	dispatcher := NewDefaultGenerator()
 	if err := dispatcher.Prepare(); err == nil {
 	if err := dispatcher.Prepare(); err == nil {
 		g := NewRpcGenerator(dispatcher)
 		g := NewRpcGenerator(dispatcher)
@@ -29,6 +30,7 @@ func TestRpcGenerateCaseNilImport(t *testing.T) {
 }
 }
 
 
 func TestRpcGenerateCaseOption(t *testing.T) {
 func TestRpcGenerateCaseOption(t *testing.T) {
+	_ = Clean()
 	dispatcher := NewDefaultGenerator()
 	dispatcher := NewDefaultGenerator()
 	if err := dispatcher.Prepare(); err == nil {
 	if err := dispatcher.Prepare(); err == nil {
 		g := NewRpcGenerator(dispatcher)
 		g := NewRpcGenerator(dispatcher)
@@ -47,6 +49,7 @@ func TestRpcGenerateCaseOption(t *testing.T) {
 }
 }
 
 
 func TestRpcGenerateCaseWordOption(t *testing.T) {
 func TestRpcGenerateCaseWordOption(t *testing.T) {
+	_ = Clean()
 	dispatcher := NewDefaultGenerator()
 	dispatcher := NewDefaultGenerator()
 	if err := dispatcher.Prepare(); err == nil {
 	if err := dispatcher.Prepare(); err == nil {
 		g := NewRpcGenerator(dispatcher)
 		g := NewRpcGenerator(dispatcher)
@@ -66,6 +69,7 @@ func TestRpcGenerateCaseWordOption(t *testing.T) {
 
 
 // test keyword go
 // test keyword go
 func TestRpcGenerateCaseGoOption(t *testing.T) {
 func TestRpcGenerateCaseGoOption(t *testing.T) {
+	_ = Clean()
 	dispatcher := NewDefaultGenerator()
 	dispatcher := NewDefaultGenerator()
 	if err := dispatcher.Prepare(); err == nil {
 	if err := dispatcher.Prepare(); err == nil {
 		g := NewRpcGenerator(dispatcher)
 		g := NewRpcGenerator(dispatcher)
@@ -84,6 +88,7 @@ func TestRpcGenerateCaseGoOption(t *testing.T) {
 }
 }
 
 
 func TestRpcGenerateCaseImport(t *testing.T) {
 func TestRpcGenerateCaseImport(t *testing.T) {
+	_ = Clean()
 	dispatcher := NewDefaultGenerator()
 	dispatcher := NewDefaultGenerator()
 	if err := dispatcher.Prepare(); err == nil {
 	if err := dispatcher.Prepare(); err == nil {
 		g := NewRpcGenerator(dispatcher)
 		g := NewRpcGenerator(dispatcher)
@@ -102,3 +107,22 @@ func TestRpcGenerateCaseImport(t *testing.T) {
 		}())
 		}())
 	}
 	}
 }
 }
+
+func TestRpcGenerateCaseServiceRpcNamingSnake(t *testing.T) {
+	_ = Clean()
+	dispatcher := NewDefaultGenerator()
+	if err := dispatcher.Prepare(); err == nil {
+		g := NewRpcGenerator(dispatcher)
+		abs, err := filepath.Abs("./test")
+		assert.Nil(t, err)
+
+		err = g.Generate("./test_service_rpc_naming_snake.proto", abs, nil)
+		defer func() {
+			_ = os.RemoveAll(abs)
+		}()
+		assert.Nil(t, err)
+
+		_, err = execx.Run("go test "+abs, abs)
+		assert.Nil(t, err)
+	}
+}

+ 7 - 6
tools/goctl/rpc/generator/gencall.go

@@ -52,7 +52,7 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
 
 
 	callFunctionTemplate = `
 	callFunctionTemplate = `
 {{if .hasComment}}{{.comment}}{{end}}
 {{if .hasComment}}{{.comment}}{{end}}
-func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
+func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
 	client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
 	client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
 	return client.{{.method}}(ctx, in)
 	return client.{{.method}}(ctx, in)
 }
 }
@@ -90,9 +90,9 @@ func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
 		"name":        formatFilename(service.Name),
 		"name":        formatFilename(service.Name),
 		"alias":       strings.Join(alias.KeysStr(), util.NL),
 		"alias":       strings.Join(alias.KeysStr(), util.NL),
 		"head":        head,
 		"head":        head,
-		"filePackage": formatFilename(service.Name),
+		"filePackage": dir.Base,
 		"package":     fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
 		"package":     fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
-		"serviceName": parser.CamelCase(service.Name),
+		"serviceName": stringx.From(service.Name).ToCamel(),
 		"functions":   strings.Join(functions, util.NL),
 		"functions":   strings.Join(functions, util.NL),
 		"interface":   strings.Join(iFunctions, util.NL),
 		"interface":   strings.Join(iFunctions, util.NL),
 	}, filename, true)
 	}, filename, true)
@@ -109,8 +109,9 @@ func (g *defaultGenerator) genFunction(goPackage string, service parser.Service)
 
 
 		comment := parser.GetComment(rpc.Doc())
 		comment := parser.GetComment(rpc.Doc())
 		buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
 		buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
-			"rpcServiceName": stringx.From(service.Name).Title(),
-			"method":         stringx.From(rpc.Name).Title(),
+			"serviceName":    stringx.From(service.Name).ToCamel(),
+			"rpcServiceName": parser.CamelCase(service.Name),
+			"method":         parser.CamelCase(rpc.Name),
 			"package":        goPackage,
 			"package":        goPackage,
 			"pbRequest":      parser.CamelCase(rpc.RequestType),
 			"pbRequest":      parser.CamelCase(rpc.RequestType),
 			"pbResponse":     parser.CamelCase(rpc.ReturnsType),
 			"pbResponse":     parser.CamelCase(rpc.ReturnsType),
@@ -140,7 +141,7 @@ func (g *defaultGenerator) getInterfaceFuncs(service parser.Service) ([]string,
 			map[string]interface{}{
 			map[string]interface{}{
 				"hasComment": len(comment) > 0,
 				"hasComment": len(comment) > 0,
 				"comment":    comment,
 				"comment":    comment,
-				"method":     stringx.From(rpc.Name).Title(),
+				"method":     parser.CamelCase(rpc.Name),
 				"pbRequest":  parser.CamelCase(rpc.RequestType),
 				"pbRequest":  parser.CamelCase(rpc.RequestType),
 				"pbResponse": parser.CamelCase(rpc.ReturnsType),
 				"pbResponse": parser.CamelCase(rpc.ReturnsType),
 			})
 			})

+ 1 - 1
tools/goctl/rpc/generator/genlogic.go

@@ -63,7 +63,7 @@ func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error {
 			return err
 			return err
 		}
 		}
 		err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
 		err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
-			"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()),
+			"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
 			"functions": functions,
 			"functions": functions,
 			"imports":   strings.Join(imports.KeysStr(), util.NL),
 			"imports":   strings.Join(imports.KeysStr(), util.NL),
 		}, filename, false)
 		}, filename, false)

+ 3 - 1
tools/goctl/rpc/generator/genmain.go

@@ -7,6 +7,7 @@ import (
 
 
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 )
 
 
 const mainTemplate = `{{.head}}
 const mainTemplate = `{{.head}}
@@ -32,7 +33,7 @@ func main() {
 	var c config.Config
 	var c config.Config
 	conf.MustLoad(*configFile, &c)
 	conf.MustLoad(*configFile, &c)
 	ctx := svc.NewServiceContext(c)
 	ctx := svc.NewServiceContext(c)
-	srv := server.New{{.service}}Server(ctx)
+	srv := server.New{{.serviceNew}}Server(ctx)
 
 
 	s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
 	s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
 		{{.pkg}}.Register{{.service}}Server(grpcServer, srv)
 		{{.pkg}}.Register{{.service}}Server(grpcServer, srv)
@@ -65,6 +66,7 @@ func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
 		"serviceName": serviceNameLower,
 		"serviceName": serviceNameLower,
 		"imports":     strings.Join(imports, util.NL),
 		"imports":     strings.Join(imports, util.NL),
 		"pkg":         proto.PbPackage,
 		"pkg":         proto.PbPackage,
+		"serviceNew":  stringx.From(proto.Service.Name).ToCamel(),
 		"service":     parser.CamelCase(proto.Service.Name),
 		"service":     parser.CamelCase(proto.Service.Name),
 	}, fileName, false)
 	}, fileName, false)
 }
 }

+ 1 - 1
tools/goctl/rpc/generator/genpb.go

@@ -20,7 +20,7 @@ func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto
 	cw.WriteString(" -I=" + base)
 	cw.WriteString(" -I=" + base)
 	cw.WriteString(" " + proto.Name)
 	cw.WriteString(" " + proto.Name)
 	if strings.Contains(proto.GoPackage, "/") {
 	if strings.Contains(proto.GoPackage, "/") {
-		cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetInternal().Filename)
+		cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetMain().Filename)
 	} else {
 	} else {
 		cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
 		cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
 	}
 	}

+ 3 - 3
tools/goctl/rpc/generator/genserver.go

@@ -67,7 +67,7 @@ func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
 
 
 	err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
 	err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
 		"head":    head,
 		"head":    head,
-		"server":  stringx.From(service.Name).Title(),
+		"server":  stringx.From(service.Name).ToCamel(),
 		"imports": strings.Join(imports.KeysStr(), util.NL),
 		"imports": strings.Join(imports.KeysStr(), util.NL),
 		"funcs":   strings.Join(funcList, util.NL),
 		"funcs":   strings.Join(funcList, util.NL),
 	}, serverFile, true)
 	}, serverFile, true)
@@ -84,8 +84,8 @@ func (g *defaultGenerator) genFunctions(goPackage string, service parser.Service
 
 
 		comment := parser.GetComment(rpc.Doc())
 		comment := parser.GetComment(rpc.Doc())
 		buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
 		buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
-			"server":     stringx.From(service.Name).Title(),
-			"logicName":  fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()),
+			"server":     stringx.From(service.Name).ToCamel(),
+			"logicName":  fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
 			"method":     parser.CamelCase(rpc.Name),
 			"method":     parser.CamelCase(rpc.Name),
 			"request":    fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
 			"request":    fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
 			"response":   fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
 			"response":   fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),

+ 5 - 1
tools/goctl/rpc/generator/mkdir.go

@@ -53,8 +53,12 @@ func mkdir(ctx *ctx.ProjectContext, proto parser.Proto) (DirContext, error) {
 	logicDir := filepath.Join(internalDir, "logic")
 	logicDir := filepath.Join(internalDir, "logic")
 	serverDir := filepath.Join(internalDir, "server")
 	serverDir := filepath.Join(internalDir, "server")
 	svcDir := filepath.Join(internalDir, "svc")
 	svcDir := filepath.Join(internalDir, "svc")
-	pbDir := filepath.Join(internalDir, proto.GoPackage)
+	pbDir := filepath.Join(ctx.WorkDir, proto.GoPackage)
 	callDir := filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name).ToCamel()))
 	callDir := filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name).ToCamel()))
+	if strings.ToLower(proto.Service.Name) == strings.ToLower(proto.GoPackage) {
+		callDir = filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name+"_client").ToCamel()))
+	}
+
 	inner[wd] = Dir{
 	inner[wd] = Dir{
 		Filename: ctx.WorkDir,
 		Filename: ctx.WorkDir,
 		Package:  filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(ctx.WorkDir, ctx.Dir))),
 		Package:  filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(ctx.WorkDir, ctx.Dir))),

+ 5 - 5
tools/goctl/rpc/generator/test.proto

@@ -6,11 +6,11 @@ option go_package = "go";
 
 
 import "test_base.proto";
 import "test_base.proto";
 
 
-message TestMessage{
+message TestMessage {
   base.CommonReq req = 1;
   base.CommonReq req = 1;
 }
 }
-message TestReq{}
-message TestReply{
+message TestReq {}
+message TestReply {
   base.CommonReply reply = 2;
   base.CommonReply reply = 2;
 }
 }
 
 
@@ -20,6 +20,6 @@ enum TestEnum {
   female = 2;
   female = 2;
 }
 }
 
 
-service TestService{
-  rpc TestRpc (TestReq)returns(TestReply);
+service TestService {
+  rpc TestRpc (TestReq) returns (TestReply);
 }
 }

+ 1 - 1
tools/goctl/rpc/generator/test_go_option.proto

@@ -14,5 +14,5 @@ message StreamResp {
 }
 }
 
 
 service StreamGreeter {
 service StreamGreeter {
-  rpc greet(StreamReq) returns (StreamResp);
+  rpc greet (StreamReq) returns (StreamResp);
 }
 }

+ 1 - 1
tools/goctl/rpc/generator/test_import.proto

@@ -14,5 +14,5 @@ message Out {
 }
 }
 
 
 service StreamGreeter {
 service StreamGreeter {
-  rpc greet(In) returns (Out);
+  rpc greet (In) returns (Out);
 }
 }

+ 1 - 1
tools/goctl/rpc/generator/test_option.proto

@@ -14,5 +14,5 @@ message StreamResp {
 }
 }
 
 
 service StreamGreeter {
 service StreamGreeter {
-  rpc greet(StreamReq) returns (StreamResp);
+  rpc greet (StreamReq) returns (StreamResp);
 }
 }

+ 27 - 0
tools/goctl/rpc/generator/test_service_rpc_naming_snake.proto

@@ -0,0 +1,27 @@
+// test proto
+syntax = "proto3";
+
+package snake_package;
+
+message StreamReq {
+  string name = 1;
+}
+
+message Stream_Resp {
+  string greet = 1;
+}
+
+message lowercase {
+  string in = 1;
+  string lower = 2;
+}
+
+message CamelCase {
+  string Camel = 1;
+}
+
+service Stream_Greeter {
+  rpc snake_service(StreamReq) returns (Stream_Resp);
+  rpc ServiceCamelCase(CamelCase) returns (CamelCase);
+  rpc servicelowercase(lowercase) returns (lowercase);
+}

+ 2 - 1
tools/goctl/rpc/generator/test_stream.proto

@@ -12,5 +12,6 @@ message StreamResp {
 }
 }
 
 
 service StreamGreeter {
 service StreamGreeter {
-  rpc greet(StreamReq) returns (StreamResp);
+  // greet service
+  rpc greet (StreamReq) returns (StreamResp);
 }
 }

+ 1 - 1
tools/goctl/rpc/parser/comment.go

@@ -6,5 +6,5 @@ func GetComment(comment *proto.Comment) string {
 	if comment == nil {
 	if comment == nil {
 		return ""
 		return ""
 	}
 	}
-	return comment.Message()
+	return "// " + comment.Message()
 }
 }

+ 2 - 2
tools/goctl/rpc/parser/test_invalid_request.proto

@@ -8,6 +8,6 @@ import "base.proto";
 message Reply{}
 message Reply{}
 
 
 
 
-service TestService{
-  rpc TestRpcTwo (base.Req)returns(Reply);
+service TestService {
+  rpc TestRpcTwo (base.Req) returns (Reply);
 }
 }

+ 2 - 2
tools/goctl/rpc/parser/test_invalid_response.proto

@@ -8,6 +8,6 @@ import "base.proto";
 message Req{}
 message Req{}
 
 
 
 
-service TestService{
-  rpc TestRpcTwo (Req)returns(base.Reply);
+service TestService {
+  rpc TestRpcTwo (Req) returns (base.Reply);
 }
 }

+ 5 - 4
tools/goctl/rpc/parser/test_option.proto

@@ -2,9 +2,10 @@ syntax = "proto3";
 
 
 package stream;
 package stream;
 
 
-option go_package="github.com/tal-tech/go-zero";
+option go_package = "github.com/tal-tech/go-zero";
 
 
-message placeholder{}
-service greet{
-  rpc hello(placeholder)returns(placeholder);
+message placeholder {}
+
+service greet {
+  rpc hello (placeholder) returns (placeholder);
 }
 }

+ 4 - 3
tools/goctl/rpc/parser/test_option2.proto

@@ -3,7 +3,8 @@ syntax = "proto3";
 package stream;
 package stream;
 
 
 
 
-message placeholder{}
-service greet{
-  rpc hello(placeholder)returns(placeholder);
+message placeholder {}
+
+service greet {
+  rpc hello (placeholder) returns (placeholder);
 }
 }