gencall.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package gen
  2. import (
  3. "fmt"
  4. "os"
  5. "os/exec"
  6. "path/filepath"
  7. "strings"
  8. "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
  9. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  10. "github.com/tal-tech/go-zero/tools/goctl/util"
  11. )
  const (
	// callTemplateTypes renders types.go of the call package. It holds the
	// message type code produced by file.GenTypesCode (passed in as .types)
	// and the errJsonConvert sentinel that the generated methods return when
	// a JSON conversion fails.
	callTemplateTypes = `{{.head}}
package {{.filePackage}}
import "errors"
var errJsonConvert = errors.New("json convert error")
{{.types}}
`

	// callTemplateText renders <service>.go of the call package. It contains:
	// - a go:generate directive that lets mockgen write <service>_mock.go;
	// - the service interface (.interface);
	// - a default<Service> struct wrapping an rpcx.Client;
	// - the New<Service> constructor;
	// - the rendered method implementations (.functions).
	callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
package {{.filePackage}}
import (
"context"
{{.package}}
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/rpcx"
)
type (
{{.serviceName}} interface {
{{.interface}}
}
default{{.serviceName}} struct {
cli rpcx.Client
}
)
func New{{.serviceName}}(cli rpcx.Client) {{.serviceName}} {
return &default{{.serviceName}}{
cli: cli,
}
}
{{.functions}}
`

	// callInterfaceFunctionTemplate renders a single method signature inside
	// the service interface, optionally preceded by its comment. Without
	// .hasResponse the method returns only an error.
	callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) {{if .hasResponse}}(*{{.pbResponse}},{{end}} error{{if .hasResponse}}){{end}}`

	// callFunctionTemplate renders a single method of default<Service>. It
	// works in four steps:
	// 1. Copy the caller's request into the pb request type via a JSON
	//    round trip.
	// 2. Invoke the generated pb client over m.cli.Conn().
	// 3. When .hasResponse is set, copy the pb response back into the call
	//    package's type the same way.
	// 4. Report any JSON failure as errJsonConvert.
	callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) {{if .hasResponse}}(*{{.pbResponse}},{{end}} error{{if .hasResponse}}){{end}} {
var request {{.package}}.{{.pbRequest}}
bts, err := jsonx.Marshal(in)
if err != nil {
return {{if .hasResponse}}nil, {{end}}errJsonConvert
}
err = jsonx.Unmarshal(bts, &request)
if err != nil {
return {{if .hasResponse}}nil, {{end}}errJsonConvert
}
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
{{if .hasResponse}}resp, err := {{else}}_, err = {{end}}client.{{.method}}(ctx, &request)
{{if .hasResponse}}if err != nil{
return nil, err
}
var ret {{.pbResponse}}
bts, err = jsonx.Marshal(resp)
if err != nil{
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &ret)
if err != nil{
return nil, errJsonConvert
}
return &ret, nil{{else}}if err != nil {
return err
}
return nil{{end}}
}
`
)
  // genCall generates the client-side call package for the single service in
// the parsed proto. It writes two files:
// - types.go, holding the message types;
// - <service>.go, holding the service interface plus a default
//   implementation that forwards each method to the generated pb client.
//
// The package directory is the lower-cased service name, made absolute
// against the current working directory by filepath.Abs and created if
// missing.
//
// genCall is a no-op when the proto declares no service. It returns an error
// when the proto declares more than one. When mockgen is found on PATH,
// `go generate` is run on <service>.go. That executes the go:generate
// directive emitted by callTemplateText, which writes <service>_mock.go.
func (g *defaultRpcGenerator) genCall() error {
	file := g.ast
	if len(file.Service) == 0 {
		return nil
	}
	if len(file.Service) > 1 {
		return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service))
	}

	typeCode, err := file.GenTypesCode()
	if err != nil {
		return err
	}

	service := file.Service[0]
	lowerName := service.Name.Lower()
	callPath, err := filepath.Abs(lowerName)
	if err != nil {
		return err
	}
	if err = util.MkdirIfNotExist(callPath); err != nil {
		return err
	}

	pbPkg := file.Package
	// Import spec for the pb package, e.g. `pkg "module/path/pb"`.
	remotePackage := fmt.Sprintf(`%v "%v"`, pbPkg, g.mustGetPackage(dirPb))
	head := util.GetHead(g.Ctx.ProtoSource)

	typesFile := filepath.Join(callPath, "types.go")
	err = util.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{
		"head":                  head,
		"filePackage":           lowerName,
		"pbPkg":                 pbPkg,
		"serviceName":           g.Ctx.ServiceName.Title(),
		"lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
		"types":                 typeCode,
	}, typesFile, true)
	if err != nil {
		return err
	}

	functions, err := g.getFuncs(service)
	if err != nil {
		return err
	}
	iFunctions, err := g.getInterfaceFuncs(service)
	if err != nil {
		return err
	}

	// Remove the previous mock before regenerating. It is only recreated
	// below when mockgen is installed. Presumably this keeps a mock built from
	// an older interface from lingering next to the new code.
	// This is best effort: any error (e.g. the file not existing) is
	// deliberately ignored.
	mockFile := filepath.Join(callPath, lowerName+"_mock.go")
	_ = os.Remove(mockFile)

	callFile := filepath.Join(callPath, lowerName+".go")
	err = util.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{
		"name":        lowerName,
		"head":        head,
		"filePackage": lowerName,
		"pbPkg":       pbPkg,
		"package":     remotePackage,
		"serviceName": service.Name.Title(),
		"functions":   strings.Join(functions, "\n"),
		"interface":   strings.Join(iFunctions, "\n"),
	}, callFile, true)
	if err != nil {
		return err
	}

	// If mockgen is installed, run the file's go:generate directive so gomock
	// code is produced for the service interface.
	if _, lookErr := exec.LookPath("mockgen"); lookErr == nil {
		// NOTE(review): whatever execx.Run returns is discarded, so a failing
		// `go generate` goes unreported. Consider surfacing it to the caller.
		execx.Run(fmt.Sprintf("go generate %s", callFile))
	}
	return nil
}
  146. func (g *defaultRpcGenerator) getFuncs(service *parser.RpcService) ([]string, error) {
  147. file := g.ast
  148. pkgName := file.Package
  149. functions := make([]string, 0)
  150. for _, method := range service.Funcs {
  151. data, found := file.Strcuts[strings.ToLower(method.OutType)]
  152. if found {
  153. found = len(data.Field) > 0
  154. }
  155. var comment string
  156. if len(method.Document) > 0 {
  157. comment = method.Document[0]
  158. }
  159. buffer, err := util.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{
  160. "rpcServiceName": service.Name.Title(),
  161. "method": method.Name.Title(),
  162. "package": pkgName,
  163. "pbRequest": method.InType,
  164. "pbResponse": method.OutType,
  165. "hasResponse": found,
  166. "hasComment": len(method.Document) > 0,
  167. "comment": comment,
  168. })
  169. if err != nil {
  170. return nil, err
  171. }
  172. functions = append(functions, buffer.String())
  173. }
  174. return functions, nil
  175. }
  176. func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
  177. file := g.ast
  178. functions := make([]string, 0)
  179. for _, method := range service.Funcs {
  180. data, found := file.Strcuts[strings.ToLower(method.OutType)]
  181. if found {
  182. found = len(data.Field) > 0
  183. }
  184. var comment string
  185. if len(method.Document) > 0 {
  186. comment = method.Document[0]
  187. }
  188. buffer, err := util.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute(
  189. map[string]interface{}{
  190. "hasComment": len(method.Document) > 0,
  191. "comment": comment,
  192. "method": method.Name.Title(),
  193. "pbRequest": method.InType,
  194. "pbResponse": method.OutType,
  195. "hasResponse": found,
  196. })
  197. if err != nil {
  198. return nil, err
  199. }
  200. functions = append(functions, buffer.String())
  201. }
  202. return functions, nil
  203. }