genserver.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/templatex"
  9. "github.com/tal-tech/go-zero/tools/goctl/util"
  10. )
const (
	// serverTemplate is the template for one generated server source file.
	// genHandler renders it with the keys "head", "imports", "types",
	// "server" and "funcs".
	serverTemplate = `{{.head}}
package server
import (
"context"
{{.imports}}
)
type {{.types}}
func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
return &{{.server}}Server{
svcCtx: svcCtx,
}
}
{{.funcs}}
`
	// functionTemplate is the template for a single server method that
	// forwards the call to the matching logic type. genFunctions renders it
	// once per method with the keys "server", "logicName", "method",
	// "request", "response", "hasComment" and "comment" (it also passes
	// "package", which this template does not reference).
	functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
`
	// typeFmt is a fmt.Sprintf format for the server struct declaration; the
	// single %s verb receives service.Name.Title() in genHandler, and the
	// result is substituted for {{.types}} in serverTemplate.
	typeFmt = `%sServer struct {
svcCtx *svc.ServiceContext
}`
)
  37. func (g *defaultRpcGenerator) genHandler() error {
  38. serverPath := g.dirM[dirServer]
  39. file := g.ast
  40. logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
  41. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  42. imports := collection.NewSet()
  43. imports.AddStr(logicImport, svcImport)
  44. head := templatex.GetHead(g.Ctx.ProtoSource)
  45. for _, service := range file.Service {
  46. filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
  47. serverFile := filepath.Join(serverPath, filename)
  48. funcList, importList, err := g.genFunctions(service)
  49. if err != nil {
  50. return err
  51. }
  52. imports.AddStr(importList...)
  53. err = templatex.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{
  54. "head": head,
  55. "types": fmt.Sprintf(typeFmt, service.Name.Title()),
  56. "server": service.Name.Title(),
  57. "imports": strings.Join(imports.KeysStr(), util.NL),
  58. "funcs": strings.Join(funcList, util.NL),
  59. }, serverFile, true)
  60. if err != nil {
  61. return err
  62. }
  63. }
  64. return nil
  65. }
  66. func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) {
  67. file := g.ast
  68. pkg := file.Package
  69. var functionList []string
  70. imports := collection.NewSet()
  71. for _, method := range service.Funcs {
  72. if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg {
  73. imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)))
  74. }
  75. imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
  76. imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
  77. buffer, err := templatex.With("func").Parse(functionTemplate).Execute(map[string]interface{}{
  78. "server": service.Name.Title(),
  79. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  80. "method": method.Name.Title(),
  81. "package": pkg,
  82. "request": method.ParameterIn.StarExpression,
  83. "response": method.ParameterOut.StarExpression,
  84. "hasComment": method.HaveDoc(),
  85. "comment": method.GetDoc(),
  86. })
  87. if err != nil {
  88. return nil, nil, err
  89. }
  90. functionList = append(functionList, buffer.String())
  91. }
  92. return functionList, imports.KeysStr(), nil
  93. }