genshared.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package gen
  2. import (
  3. "fmt"
  4. "os"
  5. "os/exec"
  6. "path/filepath"
  7. "strings"
  8. "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
  9. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  10. "github.com/tal-tech/go-zero/tools/goctl/util"
  11. )
  12. const (
  13. sharedTemplateText = `{{.head}}
  14. //go:generate mockgen -destination ./{{.name}}model_mock.go -package {{.filePackage}} -source $GOFILE
  15. package {{.filePackage}}
  16. import (
  17. "context"
  18. {{.package}}
  19. "github.com/tal-tech/go-zero/core/jsonx"
  20. "github.com/tal-tech/go-zero/rpcx"
  21. )
  22. type (
  23. {{.serviceName}}Model interface {
  24. {{.interface}}
  25. }
  26. default{{.serviceName}}Model struct {
  27. cli rpcx.Client
  28. }
  29. )
  30. func New{{.serviceName}}Model(cli rpcx.Client) {{.serviceName}}Model {
  31. return &default{{.serviceName}}Model{
  32. cli: cli,
  33. }
  34. }
  35. {{.functions}}
  36. `
  37. sharedTemplateTypes = `{{.head}}
  38. package {{.filePackage}}
  39. import "errors"
  40. var errJsonConvert = errors.New("json convert error")
  41. {{.types}}
  42. `
  43. sharedInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
  44. {{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) {{if .hasResponse}}(*{{.pbResponse}},{{end}} error{{if .hasResponse}}){{end}}`
  45. sharedFunctionTemplate = `
  46. {{if .hasComment}}{{.comment}}{{end}}
  47. func (m *default{{.rpcServiceName}}Model) {{.method}}(ctx context.Context,in *{{.pbRequest}}) {{if .hasResponse}}(*{{.pbResponse}},{{end}} error{{if .hasResponse}}){{end}} {
  48. var request {{.package}}.{{.pbRequest}}
  49. bts, err := jsonx.Marshal(in)
  50. if err != nil {
  51. return {{if .hasResponse}}nil, {{end}}errJsonConvert
  52. }
  53. err = jsonx.Unmarshal(bts, &request)
  54. if err != nil {
  55. return {{if .hasResponse}}nil, {{end}}errJsonConvert
  56. }
  57. client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
  58. {{if .hasResponse}}resp, err := {{else}}_, err = {{end}}client.{{.method}}(ctx, &request)
  59. {{if .hasResponse}}if err != nil{
  60. return nil, err
  61. }
  62. var ret {{.pbResponse}}
  63. bts, err = jsonx.Marshal(resp)
  64. if err != nil{
  65. return nil, errJsonConvert
  66. }
  67. err = jsonx.Unmarshal(bts, &ret)
  68. if err != nil{
  69. return nil, errJsonConvert
  70. }
  71. return &ret, nil{{else}}if err != nil {
  72. return err
  73. }
  74. return nil{{end}}
  75. }
  76. `
  77. )
  78. func (g *defaultRpcGenerator) genShared() error {
  79. sharePackage := filepath.Base(g.Ctx.SharedDir)
  80. file := g.ast
  81. typeCode, err := file.GenTypesCode()
  82. if err != nil {
  83. return err
  84. }
  85. pbPkg := file.Package
  86. remotePackage := fmt.Sprintf(`%v "%v"`, pbPkg, g.mustGetPackage(dirPb))
  87. filename := filepath.Join(g.Ctx.SharedDir, "types.go")
  88. head := util.GetHead(g.Ctx.ProtoSource)
  89. err = util.With("types").GoFmt(true).Parse(sharedTemplateTypes).SaveTo(map[string]interface{}{
  90. "head": head,
  91. "filePackage": sharePackage,
  92. "pbPkg": pbPkg,
  93. "serviceName": g.Ctx.ServiceName.Title(),
  94. "lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
  95. "types": typeCode,
  96. }, filename, true)
  97. for _, service := range file.Service {
  98. filename := filepath.Join(g.Ctx.SharedDir, fmt.Sprintf("%smodel.go", service.Name.Lower()))
  99. functions, err := g.getFuncs(service)
  100. if err != nil {
  101. return err
  102. }
  103. iFunctions, err := g.getInterfaceFuncs(service)
  104. if err != nil {
  105. return err
  106. }
  107. mockFile := filepath.Join(g.Ctx.SharedDir, fmt.Sprintf("%smodel_mock.go", service.Name.Lower()))
  108. os.Remove(mockFile)
  109. err = util.With("shared").GoFmt(true).Parse(sharedTemplateText).SaveTo(map[string]interface{}{
  110. "name": service.Name.Lower(),
  111. "head": head,
  112. "filePackage": sharePackage,
  113. "pbPkg": pbPkg,
  114. "package": remotePackage,
  115. "serviceName": service.Name.Title(),
  116. "functions": strings.Join(functions, "\n"),
  117. "interface": strings.Join(iFunctions, "\n"),
  118. }, filename, true)
  119. if err != nil {
  120. return err
  121. }
  122. }
  123. // if mockgen is already installed, it will generate code of gomock for shared files
  124. _, err = exec.LookPath("mockgen")
  125. if err != nil {
  126. g.Ctx.Warning("warning:mockgen is not found")
  127. } else {
  128. execx.Run(fmt.Sprintf("cd %s \ngo generate", g.Ctx.SharedDir))
  129. }
  130. return nil
  131. }
  132. func (g *defaultRpcGenerator) getFuncs(service *parser.RpcService) ([]string, error) {
  133. file := g.ast
  134. pkgName := file.Package
  135. functions := make([]string, 0)
  136. for _, method := range service.Funcs {
  137. data, found := file.Strcuts[strings.ToLower(method.OutType)]
  138. if found {
  139. found = len(data.Field) > 0
  140. }
  141. var comment string
  142. if len(method.Document) > 0 {
  143. comment = method.Document[0]
  144. }
  145. buffer, err := util.With("sharedFn").Parse(sharedFunctionTemplate).Execute(map[string]interface{}{
  146. "rpcServiceName": service.Name.Title(),
  147. "method": method.Name.Title(),
  148. "package": pkgName,
  149. "pbRequest": method.InType,
  150. "pbResponse": method.OutType,
  151. "hasResponse": found,
  152. "hasComment": len(method.Document) > 0,
  153. "comment": comment,
  154. })
  155. if err != nil {
  156. return nil, err
  157. }
  158. functions = append(functions, buffer.String())
  159. }
  160. return functions, nil
  161. }
  162. func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
  163. file := g.ast
  164. functions := make([]string, 0)
  165. for _, method := range service.Funcs {
  166. data, found := file.Strcuts[strings.ToLower(method.OutType)]
  167. if found {
  168. found = len(data.Field) > 0
  169. }
  170. var comment string
  171. if len(method.Document) > 0 {
  172. comment = method.Document[0]
  173. }
  174. buffer, err := util.With("interfaceFn").Parse(sharedInterfaceFunctionTemplate).Execute(map[string]interface{}{
  175. "hasComment": len(method.Document) > 0,
  176. "comment": comment,
  177. "method": method.Name.Title(),
  178. "pbRequest": method.InType,
  179. "pbResponse": method.OutType,
  180. "hasResponse": found,
  181. })
  182. if err != nil {
  183. return nil, err
  184. }
  185. functions = append(functions, buffer.String())
  186. }
  187. return functions, nil
  188. }