genserver.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. package generator
  2. import (
  3. _ "embed"
  4. "fmt"
  5. "path/filepath"
  6. "strings"
  7. "github.com/wuntsong-org/go-zero-plus/core/collection"
  8. conf "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/config"
  9. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/rpc/parser"
  10. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util"
  11. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/format"
  12. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/pathx"
  13. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/stringx"
  14. )
// functionTemplate is the text/template source for a single generated server
// method. genFunctions passes it to pathx.LoadTemplate as the third argument
// (presumably the built-in fallback used when no custom serverFuncTemplateFile
// template is installed — confirm against pathx).
//
// Template inputs, as supplied by genFunctions:
//   - hasComment/comment: optional doc comment emitted above the method.
//   - server: camel-cased service name; the receiver type is {{.server}}Server.
//   - method: the RPC method name, used for both the server and logic method.
//   - logicPkg/logicName: package and type of the logic object constructed
//     via {{.logicPkg}}.New{{.logicName}}.
//   - request/response: pointer types of the pb request/response messages.
//   - hasReq: true when the request is not client-streamed, so an `in`
//     parameter is declared and forwarded.
//   - notStream: true for unary RPCs, which take ctx and return
//     ({{.response}}, error); otherwise the method takes `stream` and the
//     logic object is built from stream.Context().
//   - stream: true when either side streams, so `stream` is forwarded.
//   - streamBody: the pb stream server type used for the `stream` parameter.
const functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{if .hasReq}} in {{.request}}{{end}}{{else}}{{if .hasReq}} in {{.request}},{{end}}stream {{.streamBody}}{{end}}) ({{if .notStream}}{{.response}},{{end}}error) {
l := {{.logicPkg}}.New{{.logicName}}({{if .notStream}}ctx,{{else}}stream.Context(),{{end}}s.svcCtx)
return l.{{.method}}({{if .hasReq}}in{{if .stream}} ,stream{{end}}{{else}}{{if .stream}}stream{{end}}{{end}})
}
`
// serverTemplate holds the contents of server.tpl, embedded at build time.
// genServerGroup and genServerInCompatibility pass it to pathx.LoadTemplate
// as the third argument when rendering a server file (presumably the
// fallback used when no custom serverTemplateFile is installed — confirm).
//
//go:embed server.tpl
var serverTemplate string
  24. // GenServer generates rpc server file, which is an implementation of rpc server
  25. func (g *Generator) GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config,
  26. c *ZRpcContext) error {
  27. if !c.Multiple {
  28. return g.genServerInCompatibility(ctx, proto, cfg, c)
  29. }
  30. return g.genServerGroup(ctx, proto, cfg)
  31. }
  32. func (g *Generator) genServerGroup(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
  33. dir := ctx.GetServer()
  34. for _, service := range proto.Service {
  35. var (
  36. serverFile string
  37. logicImport string
  38. )
  39. serverFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name+"_server")
  40. if err != nil {
  41. return err
  42. }
  43. serverChildPkg, err := dir.GetChildPackage(service.Name)
  44. if err != nil {
  45. return err
  46. }
  47. logicChildPkg, err := ctx.GetLogic().GetChildPackage(service.Name)
  48. if err != nil {
  49. return err
  50. }
  51. serverDir := filepath.Base(serverChildPkg)
  52. logicImport = fmt.Sprintf(`"%v"`, logicChildPkg)
  53. serverFile = filepath.Join(dir.Filename, serverDir, serverFilename+".go")
  54. svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
  55. pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
  56. imports := collection.NewSet()
  57. imports.AddStr(logicImport, svcImport, pbImport)
  58. head := util.GetHead(proto.Name)
  59. funcList, err := g.genFunctions(proto.PbPackage, service, true)
  60. if err != nil {
  61. return err
  62. }
  63. text, err := pathx.LoadTemplate(category, serverTemplateFile, serverTemplate)
  64. if err != nil {
  65. return err
  66. }
  67. notStream := false
  68. for _, rpc := range service.RPC {
  69. if !rpc.StreamsRequest && !rpc.StreamsReturns {
  70. notStream = true
  71. break
  72. }
  73. }
  74. if err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]any{
  75. "head": head,
  76. "unimplementedServer": fmt.Sprintf("%s.Unimplemented%sServer", proto.PbPackage,
  77. stringx.From(service.Name).ToCamel()),
  78. "server": stringx.From(service.Name).ToCamel(),
  79. "imports": strings.Join(imports.KeysStr(), pathx.NL),
  80. "funcs": strings.Join(funcList, pathx.NL),
  81. "notStream": notStream,
  82. }, serverFile, true); err != nil {
  83. return err
  84. }
  85. }
  86. return nil
  87. }
  88. func (g *Generator) genServerInCompatibility(ctx DirContext, proto parser.Proto,
  89. cfg *conf.Config, c *ZRpcContext) error {
  90. dir := ctx.GetServer()
  91. logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
  92. svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
  93. pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
  94. imports := collection.NewSet()
  95. imports.AddStr(logicImport, svcImport, pbImport)
  96. head := util.GetHead(proto.Name)
  97. service := proto.Service[0]
  98. serverFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name+"_server")
  99. if err != nil {
  100. return err
  101. }
  102. serverFile := filepath.Join(dir.Filename, serverFilename+".go")
  103. funcList, err := g.genFunctions(proto.PbPackage, service, false)
  104. if err != nil {
  105. return err
  106. }
  107. text, err := pathx.LoadTemplate(category, serverTemplateFile, serverTemplate)
  108. if err != nil {
  109. return err
  110. }
  111. notStream := false
  112. for _, rpc := range service.RPC {
  113. if !rpc.StreamsRequest && !rpc.StreamsReturns {
  114. notStream = true
  115. break
  116. }
  117. }
  118. return util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]any{
  119. "head": head,
  120. "unimplementedServer": fmt.Sprintf("%s.Unimplemented%sServer", proto.PbPackage,
  121. stringx.From(service.Name).ToCamel()),
  122. "server": stringx.From(service.Name).ToCamel(),
  123. "imports": strings.Join(imports.KeysStr(), pathx.NL),
  124. "funcs": strings.Join(funcList, pathx.NL),
  125. "notStream": notStream,
  126. }, serverFile, true)
  127. }
  128. func (g *Generator) genFunctions(goPackage string, service parser.Service, multiple bool) ([]string, error) {
  129. var (
  130. functionList []string
  131. logicPkg string
  132. )
  133. for _, rpc := range service.RPC {
  134. text, err := pathx.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
  135. if err != nil {
  136. return nil, err
  137. }
  138. var logicName string
  139. if !multiple {
  140. logicPkg = "logic"
  141. logicName = fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel())
  142. } else {
  143. nameJoin := fmt.Sprintf("%s_logic", service.Name)
  144. logicPkg = strings.ToLower(stringx.From(nameJoin).ToCamel())
  145. logicName = fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel())
  146. }
  147. comment := parser.GetComment(rpc.Doc())
  148. streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name),
  149. parser.CamelCase(rpc.Name), "Server")
  150. buffer, err := util.With("func").Parse(text).Execute(map[string]any{
  151. "server": stringx.From(service.Name).ToCamel(),
  152. "logicName": logicName,
  153. "method": parser.CamelCase(rpc.Name),
  154. "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
  155. "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
  156. "hasComment": len(comment) > 0,
  157. "comment": comment,
  158. "hasReq": !rpc.StreamsRequest,
  159. "stream": rpc.StreamsRequest || rpc.StreamsReturns,
  160. "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
  161. "streamBody": streamServer,
  162. "logicPkg": logicPkg,
  163. })
  164. if err != nil {
  165. return nil, err
  166. }
  167. functionList = append(functionList, buffer.String())
  168. }
  169. return functionList, nil
  170. }