gencall.go 5.4 KB


  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/util"
  9. )
const (
	// typesFilename is the file name, inside the call package directory, that
	// receives the rendered callTemplateTypes output.
	typesFilename = "types.go"
	// callTemplateText is the built-in text of the <service>.go file template.
	// It renders a mockgen go:generate directive, the client interface
	// ({{.interface}}), its zrpc.Client-backed default implementation, a
	// New<Service> constructor, and the method bodies ({{.functions}}).
	callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
package {{.filePackage}}
import (
"context"
{{.package}}
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/zrpc"
)
type (
{{.serviceName}} interface {
{{.interface}}
}
default{{.serviceName}} struct {
cli zrpc.Client
}
)
func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
return &default{{.serviceName}}{
cli: cli,
}
}
{{.functions}}
`
	// callTemplateTypes is the built-in text of the types.go template. It
	// declares errJsonConvert, which the generated methods return. It is
	// followed by the enum constants ({{.const}}) and message types ({{.types}})
	// that genCall obtains from GenEnumCode and GenTypesCode.
	callTemplateTypes = `{{.head}}
package {{.filePackage}}
import "errors"
var errJsonConvert = errors.New("json convert error")
{{.const}}
{{.types}}
`
	// callInterfaceFunctionTemplate renders one method signature of the
	// generated client interface. When the rpc has a doc comment, the comment
	// is placed on the line before the signature.
	callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
	// callFunctionTemplate renders one method of default<Service>. The method:
	//   - copies the input into the pb request type with a jsonx
	//     Marshal/Unmarshal round trip,
	//   - calls the generated gRPC client built from m.cli.Conn(),
	//   - converts the response back into the call package's type the same way.
	// Any JSON conversion failure is reported as errJsonConvert; an RPC error is
	// returned unchanged.
	callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
var request {{.pbRequest}}
bts, err := jsonx.Marshal(in)
if err != nil {
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &request)
if err != nil {
return nil, errJsonConvert
}
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
resp, err := client.{{.method}}(ctx, &request)
if err != nil{
return nil, err
}
var ret {{.pbResponse}}
bts, err = jsonx.Marshal(resp)
if err != nil{
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &ret)
if err != nil{
return nil, errJsonConvert
}
return &ret, nil
}
`
)
  75. func (g *defaultRpcGenerator) genCall() error {
  76. file := g.ast
  77. if len(file.Service) == 0 {
  78. return nil
  79. }
  80. if len(file.Service) > 1 {
  81. return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service))
  82. }
  83. typeCode, err := file.GenTypesCode()
  84. if err != nil {
  85. return err
  86. }
  87. constLit, err := file.GenEnumCode()
  88. if err != nil {
  89. return err
  90. }
  91. service := file.Service[0]
  92. callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
  93. if err = util.MkdirIfNotExist(callPath); err != nil {
  94. return err
  95. }
  96. filename := filepath.Join(callPath, typesFilename)
  97. head := util.GetHead(g.Ctx.ProtoSource)
  98. text, err := util.LoadTemplate(category, callTypesTemplateFile, callTemplateTypes)
  99. if err != nil {
  100. return err
  101. }
  102. err = util.With("types").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
  103. "head": head,
  104. "const": constLit,
  105. "filePackage": service.Name.Lower(),
  106. "serviceName": g.Ctx.ServiceName.Title(),
  107. "lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
  108. "types": typeCode,
  109. }, filename, true)
  110. if err != nil {
  111. return err
  112. }
  113. filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
  114. functions, importList, err := g.genFunction(service)
  115. if err != nil {
  116. return err
  117. }
  118. iFunctions, err := g.getInterfaceFuncs(service)
  119. if err != nil {
  120. return err
  121. }
  122. text, err = util.LoadTemplate(category, callTemplateFile, callTemplateText)
  123. if err != nil {
  124. return err
  125. }
  126. err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
  127. "name": service.Name.Lower(),
  128. "head": head,
  129. "filePackage": service.Name.Lower(),
  130. "package": strings.Join(importList, util.NL),
  131. "serviceName": service.Name.Title(),
  132. "functions": strings.Join(functions, util.NL),
  133. "interface": strings.Join(iFunctions, util.NL),
  134. }, filename, true)
  135. return err
  136. }
  137. func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
  138. file := g.ast
  139. pkgName := file.Package
  140. functions := make([]string, 0)
  141. imports := collection.NewSet()
  142. imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
  143. for _, method := range service.Funcs {
  144. imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
  145. text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
  146. if err != nil {
  147. return nil, nil, err
  148. }
  149. buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
  150. "rpcServiceName": service.Name.Title(),
  151. "method": method.Name.Title(),
  152. "package": pkgName,
  153. "pbRequestName": method.ParameterIn.Name,
  154. "pbRequest": method.ParameterIn.Expression,
  155. "pbResponse": method.ParameterOut.Name,
  156. "hasComment": method.HaveDoc(),
  157. "comment": method.GetDoc(),
  158. })
  159. if err != nil {
  160. return nil, nil, err
  161. }
  162. functions = append(functions, buffer.String())
  163. }
  164. return functions, imports.KeysStr(), nil
  165. }
  166. func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
  167. functions := make([]string, 0)
  168. for _, method := range service.Funcs {
  169. text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
  170. if err != nil {
  171. return nil, err
  172. }
  173. buffer, err := util.With("interfaceFn").Parse(text).Execute(
  174. map[string]interface{}{
  175. "hasComment": method.HaveDoc(),
  176. "comment": method.GetDoc(),
  177. "method": method.Name.Title(),
  178. "pbRequest": method.ParameterIn.Name,
  179. "pbResponse": method.ParameterOut.Name,
  180. })
  181. if err != nil {
  182. return nil, err
  183. }
  184. functions = append(functions, buffer.String())
  185. }
  186. return functions, nil
  187. }