Browse Source

To generate grpc stream, fix issue #616 (#815)

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
anqiansong 3 years ago
parent
commit
db87fd3239

+ 40 - 34
tools/goctl/rpc/generator/gen_test.go

@@ -29,7 +29,6 @@ func TestRpcGenerate(t *testing.T) {
 	projectName := stringx.Rand()
 	g := NewRPCGenerator(dispatcher, cfg)
 
-	// case go path
 	src := filepath.Join(build.Default.GOPATH, "src")
 	_, err = os.Stat(src)
 	if err != nil {
@@ -41,45 +40,52 @@ func TestRpcGenerate(t *testing.T) {
 	defer func() {
 		_ = os.RemoveAll(srcDir)
 	}()
-
 	common, err := filepath.Abs(".")
 	assert.Nil(t, err)
 
-	err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
-	assert.Nil(t, err)
-	_, err = execx.Run("go test "+projectName, projectDir)
-	if err != nil {
-		assert.True(t, func() bool {
-			return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
-		}())
-	}
+	// case go path
+	t.Run("GOPATH", func(t *testing.T) {
+		err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
+		assert.Nil(t, err)
+		_, err = execx.Run("go test "+projectName, projectDir)
+		if err != nil {
+			assert.True(t, func() bool {
+				return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
+			}())
+		}
+	})
 
 	// case go mod
-	workDir := t.TempDir()
-	name := filepath.Base(workDir)
-	_, err = execx.Run("go mod init "+name, workDir)
-	if err != nil {
-		logx.Error(err)
-		return
-	}
+	t.Run("GOMOD", func(t *testing.T) {
+		workDir := t.TempDir()
+		name := filepath.Base(workDir)
+		_, err = execx.Run("go mod init "+name, workDir)
+		if err != nil {
+			logx.Error(err)
+			return
+		}
 
-	projectDir = filepath.Join(workDir, projectName)
-	err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
-	assert.Nil(t, err)
-	_, err = execx.Run("go test "+projectName, projectDir)
-	if err != nil {
-		assert.True(t, func() bool {
-			return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
-		}())
-	}
+		projectDir = filepath.Join(workDir, projectName)
+		err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
+		assert.Nil(t, err)
+		_, err = execx.Run("go test "+projectName, projectDir)
+		if err != nil {
+			assert.True(t, func() bool {
+				return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
+			}())
+		}
+
+	})
 
 	// case not in go mod and go path
-	err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
-	assert.Nil(t, err)
-	_, err = execx.Run("go test "+projectName, projectDir)
-	if err != nil {
-		assert.True(t, func() bool {
-			return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
-		}())
-	}
+	t.Run("OTHER", func(t *testing.T) {
+		err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
+		assert.Nil(t, err)
+		_, err = execx.Run("go test "+projectName, projectDir)
+		if err != nil {
+			assert.True(t, func() bool {
+				return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
+			}())
+		}
+	})
 }

+ 28 - 5
tools/goctl/rpc/generator/gencall.go

@@ -49,13 +49,13 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
 `
 
 	callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
-{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
+{{end}}{{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error)`
 
 	callFunctionTemplate = `
 {{if .hasComment}}{{.comment}}{{end}}
-func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
+func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) {
 	client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
-	return client.{{.method}}(ctx, in)
+	return client.{{.method}}(ctx,{{if .hasReq}} in{{end}})
 }
 `
 )
@@ -78,7 +78,7 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
 		return err
 	}
 
-	iFunctions, err := g.getInterfaceFuncs(service)
+	iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service)
 	if err != nil {
 		return err
 	}
@@ -115,6 +115,14 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
 		}
 
 		comment := parser.GetComment(rpc.Doc())
+		var streamServer string
+		if rpc.StreamsRequest && rpc.StreamsReturns {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient")
+		} else if rpc.StreamsRequest {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient")
+		} else {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient")
+		}
 		buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
 			"serviceName":    stringx.From(service.Name).ToCamel(),
 			"rpcServiceName": parser.CamelCase(service.Name),
@@ -124,6 +132,9 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
 			"pbResponse":     parser.CamelCase(rpc.ReturnsType),
 			"hasComment":     len(comment) > 0,
 			"comment":        comment,
+			"hasReq":         !rpc.StreamsRequest,
+			"notStream":      !rpc.StreamsRequest && !rpc.StreamsReturns,
+			"streamBody":     streamServer,
 		})
 		if err != nil {
 			return nil, err
@@ -134,7 +145,7 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
 	return functions, nil
 }
 
-func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) {
+func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Service) ([]string, error) {
 	functions := make([]string, 0)
 
 	for _, rpc := range service.RPC {
@@ -144,13 +155,25 @@ func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string,
 		}
 
 		comment := parser.GetComment(rpc.Doc())
+		var streamServer string
+		if rpc.StreamsRequest && rpc.StreamsReturns {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient")
+		} else if rpc.StreamsRequest {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient")
+		} else {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient")
+		}
+
 		buffer, err := util.With("interfaceFn").Parse(text).Execute(
 			map[string]interface{}{
 				"hasComment": len(comment) > 0,
 				"comment":    comment,
 				"method":     parser.CamelCase(rpc.Name),
+				"hasReq":     !rpc.StreamsRequest,
 				"pbRequest":  parser.CamelCase(rpc.RequestType),
+				"notStream":  !rpc.StreamsRequest && !rpc.StreamsReturns,
 				"pbResponse": parser.CamelCase(rpc.ReturnsType),
+				"streamBody": streamServer,
 			})
 		if err != nil {
 			return nil, err

+ 17 - 4
tools/goctl/rpc/generator/genlogic.go

@@ -40,10 +40,10 @@ func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logic
 {{.functions}}
 `
 	logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
-func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
+func (l *{{.logicName}}) {{.method}} ({{if .hasReq}}in {{.request}}{{if .stream}},stream {{.streamBody}}{{end}}{{else}}stream {{.streamBody}}{{end}}) ({{if .hasReply}}{{.response}},{{end}} error) {
 	// todo: add your logic here and delete this line
 	
-	return &{{.responseType}}{}, nil
+	return {{if .hasReply}}&{{.responseType}}{},{{end}} nil
 }
 `
 )
@@ -51,6 +51,7 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
 // GenLogic generates the logic file of the rpc service, which corresponds to the RPC definition items in proto.
 func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
 	dir := ctx.GetLogic()
+	service := proto.Service.Service.Name
 	for _, rpc := range proto.Service.RPC {
 		logicFilename, err := format.FileNamingFormat(cfg.NamingFormat, rpc.Name+"_logic")
 		if err != nil {
@@ -58,7 +59,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
 		}
 
 		filename := filepath.Join(dir.Filename, logicFilename+".go")
-		functions, err := g.genLogicFunction(proto.PbPackage, rpc)
+		functions, err := g.genLogicFunction(service, proto.PbPackage, rpc)
 		if err != nil {
 			return err
 		}
@@ -82,7 +83,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
 	return nil
 }
 
-func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) {
+func (g *DefaultGenerator) genLogicFunction(serviceName string, goPackage string, rpc *parser.RPC) (string, error) {
 	functions := make([]string, 0)
 	text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
 	if err != nil {
@@ -91,12 +92,24 @@ func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (
 
 	logicName := stringx.From(rpc.Name + "_logic").ToCamel()
 	comment := parser.GetComment(rpc.Doc())
+	var streamServer string
+	if rpc.StreamsRequest && rpc.StreamsReturns {
+		streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_StreamServer")
+	} else if rpc.StreamsRequest {
+		streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ClientStreamServer")
+	} else {
+		streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ServerStreamServer")
+	}
 	buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
 		"logicName":    logicName,
 		"method":       parser.CamelCase(rpc.Name),
+		"hasReq":       !rpc.StreamsRequest,
 		"request":      fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
+		"hasReply":     !rpc.StreamsRequest && !rpc.StreamsReturns,
 		"response":     fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
 		"responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
+		"stream":       rpc.StreamsRequest || rpc.StreamsReturns,
+		"streamBody":   streamServer,
 		"hasComment":   len(comment) > 0,
 		"comment":      comment,
 	})

+ 16 - 3
tools/goctl/rpc/generator/genserver.go

@@ -38,9 +38,9 @@ func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
 `
 	functionTemplate = `
 {{if .hasComment}}{{.comment}}{{end}}
-func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
-	l := logic.New{{.logicName}}(ctx,s.svcCtx)
-	return l.{{.method}}(in)
+func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{if .hasReq}} in {{.request}}{{end}}{{else}}{{if .hasReq}} in {{.request}},{{end}}stream {{.streamBody}}{{end}}) ({{if .notStream}}{{.response}},{{end}}error) {
+	l := logic.New{{.logicName}}({{if .notStream}}ctx,{{else}}stream.Context(),{{end}}s.svcCtx)
+	return l.{{.method}}({{if .hasReq}}in{{if .stream}} ,stream{{end}}{{else}}{{if .stream}}stream{{end}}{{end}})
 }
 `
 )
@@ -91,6 +91,15 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service
 		}
 
 		comment := parser.GetComment(rpc.Doc())
+		var streamServer string
+		if rpc.StreamsRequest && rpc.StreamsReturns {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamServer")
+		} else if rpc.StreamsRequest {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamServer")
+		} else {
+			streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamServer")
+		}
+
 		buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
 			"server":     stringx.From(service.Name).ToCamel(),
 			"logicName":  fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
@@ -99,6 +108,10 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service
 			"response":   fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
 			"hasComment": len(comment) > 0,
 			"comment":    comment,
+			"hasReq":     !rpc.StreamsRequest,
+			"stream":     rpc.StreamsRequest || rpc.StreamsReturns,
+			"notStream":  !rpc.StreamsRequest && !rpc.StreamsReturns,
+			"streamBody": streamServer,
 		})
 		if err != nil {
 			return nil, err

+ 6 - 0
tools/goctl/rpc/generator/test.proto

@@ -59,4 +59,10 @@ service Test_Service {
   rpc MapService (MapReq) returns (CommonReply);
   // case repeated
   rpc RepeatedService (RepeatedReq) returns (CommonReply);
+  // server stream
+  rpc ServerStream (Req) returns (stream Reply);
+  // client stream
+  rpc ClientStream (stream Req) returns (Reply);
+  // stream
+  rpc Stream(stream Req) returns (stream Reply);
 }

+ 16 - 0
tools/goctl/rpc/parser/stream.proto

@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+package test;
+
+message Req{
+  string input = 1;
+}
+
+message Reply{
+  string output = 1;
+}
+service TestService{
+  rpc ServerStream (Req) returns (stream Reply);
+  rpc ClientStream (stream Req) returns (Reply);
+  rpc Stream (stream Req) returns (stream Reply);
+}