genlogic.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/util"
  9. )
const (
// logicTemplate is the source text for one generated file in package
// "logic". It declares a {{.logicName}} struct holding a context, a
// *svc.ServiceContext and an embedded logx.Logger, plus a New{{.logicName}}
// constructor. Placeholders filled in by genLogic: .imports, .logicName and
// .functions. genLogic hands this text to util.LoadTemplate together with
// logicTemplateFileFile (presumably as the built-in default when no custom
// template exists — confirm against util.LoadTemplate).
logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
// logicFunctionTemplate is the source text for a single method stub on the
// generated logic type. The stub returns an empty {{.responseType}} value
// and a nil error. Placeholders filled in by genLogicFunction: .hasComment,
// .comment, .logicName, .method, .request, .response and .responseType.
// The rendered result is what genLogic substitutes for {{.functions}} in
// logicTemplate.
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.responseType}}{}, nil
}
`
)
  38. func (g *defaultRpcGenerator) genLogic() error {
  39. logicPath := g.dirM[dirLogic]
  40. protoPkg := g.ast.Package
  41. service := g.ast.Service
  42. for _, item := range service {
  43. for _, method := range item.Funcs {
  44. logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
  45. filename := filepath.Join(logicPath, logicName)
  46. functions, importList, err := g.genLogicFunction(protoPkg, method)
  47. if err != nil {
  48. return err
  49. }
  50. imports := collection.NewSet()
  51. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  52. imports.AddStr(svcImport)
  53. imports.AddStr(importList...)
  54. text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
  55. if err != nil {
  56. return err
  57. }
  58. err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
  59. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  60. "functions": functions,
  61. "imports": strings.Join(imports.KeysStr(), util.NL),
  62. }, filename, false)
  63. if err != nil {
  64. return err
  65. }
  66. }
  67. }
  68. return nil
  69. }
  70. func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) {
  71. var functions = make([]string, 0)
  72. var imports = collection.NewSet()
  73. if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName {
  74. imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb)))
  75. }
  76. imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
  77. imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
  78. text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
  79. if err != nil {
  80. return "", nil, err
  81. }
  82. buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
  83. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  84. "method": method.Name.Title(),
  85. "request": method.ParameterIn.StarExpression,
  86. "response": method.ParameterOut.StarExpression,
  87. "responseType": method.ParameterOut.Expression,
  88. "hasComment": method.HaveDoc(),
  89. "comment": method.GetDoc(),
  90. })
  91. if err != nil {
  92. return "", nil, err
  93. }
  94. functions = append(functions, buffer.String())
  95. return strings.Join(functions, util.NL), imports.KeysStr(), nil
  96. }