gencall.go 9.1 KB


  1. package generator
  2. import (
  3. _ "embed"
  4. "fmt"
  5. "path/filepath"
  6. "sort"
  7. "strings"
  8. "github.com/emicklei/proto"
  9. "github.com/wuntsong-org/go-zero-plus/core/collection"
  10. conf "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/config"
  11. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/rpc/parser"
  12. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util"
  13. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/format"
  14. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/pathx"
  15. "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/stringx"
  16. )
  17. const (
  18. callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
  19. {{end}}{{.method}}(ctx context.Context{{if .hasReq}}, in *{{.pbRequest}}{{end}}, opts ...grpc.CallOption) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error)`
  20. callFunctionTemplate = `
  21. {{if .hasComment}}{{.comment}}{{end}}
  22. func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}}, in *{{.pbRequest}}{{end}}, opts ...grpc.CallOption) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) {
  23. client := {{if .isCallPkgSameToGrpcPkg}}{{else}}{{.package}}.{{end}}New{{.rpcServiceName}}Client(m.cli.Conn())
  24. return client.{{.method}}(ctx{{if .hasReq}}, in{{end}}, opts...)
  25. }
  26. `
  27. )
// callTemplateText holds the contents of call.tpl, embedded at build time.
// genCallGroup and genCallInCompatibility pass it to pathx.LoadTemplate
// together with callTemplateFile (presumably as the default used when no
// user-supplied template exists — TODO confirm against pathx.LoadTemplate).
//
//go:embed call.tpl
var callTemplateText string
  30. // GenCall generates the rpc client code, which is the entry point for the rpc service call.
  31. // It is a layer of encapsulation for the rpc client and shields the details in the pb.
  32. func (g *Generator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config,
  33. c *ZRpcContext) error {
  34. if !c.Multiple {
  35. return g.genCallInCompatibility(ctx, proto, cfg)
  36. }
  37. return g.genCallGroup(ctx, proto, cfg)
  38. }
  39. func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
  40. dir := ctx.GetCall()
  41. head := util.GetHead(proto.Name)
  42. for _, service := range proto.Service {
  43. childPkg, err := dir.GetChildPackage(service.Name)
  44. if err != nil {
  45. return err
  46. }
  47. callFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name)
  48. if err != nil {
  49. return err
  50. }
  51. childDir := filepath.Base(childPkg)
  52. filename := filepath.Join(dir.Filename, childDir, fmt.Sprintf("%s.go", callFilename))
  53. isCallPkgSameToPbPkg := childDir == ctx.GetProtoGo().Filename
  54. isCallPkgSameToGrpcPkg := childDir == ctx.GetProtoGo().Filename
  55. serviceName := stringx.From(service.Name).ToCamel()
  56. alias := collection.NewSet()
  57. var hasSameNameBetweenMessageAndService bool
  58. for _, item := range proto.Message {
  59. msgName := getMessageName(*item.Message)
  60. if serviceName == msgName {
  61. hasSameNameBetweenMessageAndService = true
  62. }
  63. if !isCallPkgSameToPbPkg {
  64. alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
  65. fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
  66. }
  67. }
  68. if hasSameNameBetweenMessageAndService {
  69. serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
  70. }
  71. functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
  72. if err != nil {
  73. return err
  74. }
  75. iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
  76. if err != nil {
  77. return err
  78. }
  79. text, err := pathx.LoadTemplate(category, callTemplateFile, callTemplateText)
  80. if err != nil {
  81. return err
  82. }
  83. pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
  84. protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
  85. if isCallPkgSameToGrpcPkg {
  86. pbPackage = ""
  87. protoGoPackage = ""
  88. }
  89. aliasKeys := alias.KeysStr()
  90. sort.Strings(aliasKeys)
  91. if err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
  92. "name": callFilename,
  93. "alias": strings.Join(aliasKeys, pathx.NL),
  94. "head": head,
  95. "filePackage": childDir,
  96. "pbPackage": pbPackage,
  97. "protoGoPackage": protoGoPackage,
  98. "serviceName": serviceName,
  99. "functions": strings.Join(functions, pathx.NL),
  100. "interface": strings.Join(iFunctions, pathx.NL),
  101. }, filename, true); err != nil {
  102. return err
  103. }
  104. }
  105. return nil
  106. }
  107. func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
  108. cfg *conf.Config) error {
  109. dir := ctx.GetCall()
  110. service := proto.Service[0]
  111. head := util.GetHead(proto.Name)
  112. isCallPkgSameToPbPkg := ctx.GetCall().Filename == ctx.GetPb().Filename
  113. isCallPkgSameToGrpcPkg := ctx.GetCall().Filename == ctx.GetProtoGo().Filename
  114. callFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name)
  115. if err != nil {
  116. return err
  117. }
  118. serviceName := stringx.From(service.Name).ToCamel()
  119. alias := collection.NewSet()
  120. var hasSameNameBetweenMessageAndService bool
  121. for _, item := range proto.Message {
  122. msgName := getMessageName(*item.Message)
  123. if serviceName == msgName {
  124. hasSameNameBetweenMessageAndService = true
  125. }
  126. if !isCallPkgSameToPbPkg {
  127. alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
  128. fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
  129. }
  130. }
  131. if hasSameNameBetweenMessageAndService {
  132. serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
  133. }
  134. filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
  135. functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
  136. if err != nil {
  137. return err
  138. }
  139. iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
  140. if err != nil {
  141. return err
  142. }
  143. text, err := pathx.LoadTemplate(category, callTemplateFile, callTemplateText)
  144. if err != nil {
  145. return err
  146. }
  147. pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
  148. protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
  149. if isCallPkgSameToGrpcPkg {
  150. pbPackage = ""
  151. protoGoPackage = ""
  152. }
  153. aliasKeys := alias.KeysStr()
  154. sort.Strings(aliasKeys)
  155. return util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
  156. "name": callFilename,
  157. "alias": strings.Join(aliasKeys, pathx.NL),
  158. "head": head,
  159. "filePackage": dir.Base,
  160. "pbPackage": pbPackage,
  161. "protoGoPackage": protoGoPackage,
  162. "serviceName": serviceName,
  163. "functions": strings.Join(functions, pathx.NL),
  164. "interface": strings.Join(iFunctions, pathx.NL),
  165. }, filename, true)
  166. }
  167. func getMessageName(msg proto.Message) string {
  168. list := []string{msg.Name}
  169. for {
  170. parent := msg.Parent
  171. if parent == nil {
  172. break
  173. }
  174. parentMsg, ok := parent.(*proto.Message)
  175. if !ok {
  176. break
  177. }
  178. tmp := []string{parentMsg.Name}
  179. list = append(tmp, list...)
  180. msg = *parentMsg
  181. }
  182. return strings.Join(list, "_")
  183. }
  184. func (g *Generator) genFunction(goPackage string, serviceName string, service parser.Service,
  185. isCallPkgSameToGrpcPkg bool) ([]string, error) {
  186. functions := make([]string, 0)
  187. for _, rpc := range service.RPC {
  188. text, err := pathx.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
  189. if err != nil {
  190. return nil, err
  191. }
  192. comment := parser.GetComment(rpc.Doc())
  193. streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name),
  194. parser.CamelCase(rpc.Name), "Client")
  195. if isCallPkgSameToGrpcPkg {
  196. streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
  197. parser.CamelCase(rpc.Name), "Client")
  198. }
  199. buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]any{
  200. "serviceName": serviceName,
  201. "rpcServiceName": parser.CamelCase(service.Name),
  202. "method": parser.CamelCase(rpc.Name),
  203. "package": goPackage,
  204. "pbRequest": parser.CamelCase(rpc.RequestType),
  205. "pbResponse": parser.CamelCase(rpc.ReturnsType),
  206. "hasComment": len(comment) > 0,
  207. "comment": comment,
  208. "hasReq": !rpc.StreamsRequest,
  209. "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
  210. "streamBody": streamServer,
  211. "isCallPkgSameToGrpcPkg": isCallPkgSameToGrpcPkg,
  212. })
  213. if err != nil {
  214. return nil, err
  215. }
  216. functions = append(functions, buffer.String())
  217. }
  218. return functions, nil
  219. }
  220. func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
  221. isCallPkgSameToGrpcPkg bool) ([]string, error) {
  222. functions := make([]string, 0)
  223. for _, rpc := range service.RPC {
  224. text, err := pathx.LoadTemplate(category, callInterfaceFunctionTemplateFile,
  225. callInterfaceFunctionTemplate)
  226. if err != nil {
  227. return nil, err
  228. }
  229. comment := parser.GetComment(rpc.Doc())
  230. streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name),
  231. parser.CamelCase(rpc.Name), "Client")
  232. if isCallPkgSameToGrpcPkg {
  233. streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
  234. parser.CamelCase(rpc.Name), "Client")
  235. }
  236. buffer, err := util.With("interfaceFn").Parse(text).Execute(
  237. map[string]any{
  238. "hasComment": len(comment) > 0,
  239. "comment": comment,
  240. "method": parser.CamelCase(rpc.Name),
  241. "hasReq": !rpc.StreamsRequest,
  242. "pbRequest": parser.CamelCase(rpc.RequestType),
  243. "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
  244. "pbResponse": parser.CamelCase(rpc.ReturnsType),
  245. "streamBody": streamServer,
  246. })
  247. if err != nil {
  248. return nil, err
  249. }
  250. functions = append(functions, buffer.String())
  251. }
  252. return functions, nil
  253. }