genserver.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package generator
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/util"
  9. "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
  10. )
const (
	// serverTemplate is the template text GenServer passes to util.LoadTemplate
	// (together with serverTemplateFile) for the generated server file.
	// GenServer supplies the values for .head, .server, .imports and .funcs,
	// and renders with GoFmt(true) enabled.
	// NOTE(review): presumably this is the built-in fallback used when no
	// customized template file exists — confirm against util.LoadTemplate.
	serverTemplate = `{{.head}}
package server
import (
"context"
{{.imports}}
)
type {{.server}}Server struct {
svcCtx *svc.ServiceContext
}
func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
return &{{.server}}Server{
svcCtx: svcCtx,
}
}
{{.funcs}}
`
	// functionTemplate is the template text genFunctions passes to
	// util.LoadTemplate (together with serverFuncTemplateFile) for a single
	// server method. genFunctions supplies .server, .logicName, .method,
	// .request, .response, .hasComment and .comment; the comment line is only
	// emitted when .hasComment is true. The rendered method builds a logic
	// object via logic.New<logicName>(ctx, s.svcCtx) and delegates to its
	// method of the same name.
	functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
`
)
  36. func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
  37. dir := ctx.GetServer()
  38. logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
  39. svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
  40. pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
  41. imports := collection.NewSet()
  42. imports.AddStr(logicImport, svcImport, pbImport)
  43. head := util.GetHead(proto.Name)
  44. service := proto.Service
  45. serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server")+".go")
  46. funcList, err := g.genFunctions(proto.PbPackage, service)
  47. if err != nil {
  48. return err
  49. }
  50. text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
  51. if err != nil {
  52. return err
  53. }
  54. err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
  55. "head": head,
  56. "server": stringx.From(service.Name).Title(),
  57. "imports": strings.Join(imports.KeysStr(), util.NL),
  58. "funcs": strings.Join(funcList, util.NL),
  59. }, serverFile, true)
  60. return err
  61. }
  62. func (g *defaultGenerator) genFunctions(goPackage string, service parser.Service) ([]string, error) {
  63. var functionList []string
  64. for _, rpc := range service.RPC {
  65. text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
  66. if err != nil {
  67. return nil, err
  68. }
  69. comment := parser.GetComment(rpc.Doc())
  70. buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
  71. "server": stringx.From(service.Name).Title(),
  72. "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()),
  73. "method": parser.CamelCase(rpc.Name),
  74. "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
  75. "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
  76. "hasComment": len(comment) > 0,
  77. "comment": comment,
  78. })
  79. if err != nil {
  80. return nil, err
  81. }
  82. functionList = append(functionList, buffer.String())
  83. }
  84. return functionList, nil
  85. }