genlogic.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/templatex"
  9. "github.com/tal-tech/go-zero/tools/goctl/util"
  10. )
// Templates used to generate per-RPC-method logic files. The output of
// logicFunctionTemplate is spliced into logicTemplate through {{.functions}}.
const (
	// logicTemplate is the skeleton of one generated logic file.
	// Placeholders (values supplied by genLogic):
	//   {{.imports}}   - extra import lines, one per line; genLogic always
	//                    includes the svc package import (the template body
	//                    references svc.ServiceContext) plus the lines reported
	//                    by genLogicFunction.
	//   {{.logicName}} - name of the generated struct, "<Method>Logic".
	//   {{.functions}} - method source rendered from logicFunctionTemplate.
	// genLogic renders this template with GoFmt(true).
	logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
	// logicFunctionTemplate renders the stub implementation of one RPC method.
	// Placeholders (values supplied by genLogicFunction):
	//   {{.hasComment}}/{{.comment}} - the method's doc text, emitted only when
	//                                  method.HaveDoc() is true.
	//   {{.logicName}}   - receiver type, "<Method>Logic".
	//   {{.method}}      - title-cased method name.
	//   {{.request}}     - request parameter type (ParameterIn.StarExpression).
	//   {{.response}}    - response result type (ParameterOut.StarExpression).
	//   {{.responseType}} - ParameterOut.Expression, used to build the empty
	//                      response literal the stub returns.
	logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.responseType}}{}, nil
}
`
)
  39. func (g *defaultRpcGenerator) genLogic() error {
  40. logicPath := g.dirM[dirLogic]
  41. protoPkg := g.ast.Package
  42. service := g.ast.Service
  43. for _, item := range service {
  44. for _, method := range item.Funcs {
  45. logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
  46. filename := filepath.Join(logicPath, logicName)
  47. functions, importList, err := g.genLogicFunction(protoPkg, method)
  48. if err != nil {
  49. return err
  50. }
  51. imports := collection.NewSet()
  52. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  53. imports.AddStr(svcImport)
  54. imports.AddStr(importList...)
  55. err = templatex.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
  56. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  57. "functions": functions,
  58. "imports": strings.Join(imports.KeysStr(), util.NL),
  59. }, filename, false)
  60. if err != nil {
  61. return err
  62. }
  63. }
  64. }
  65. return nil
  66. }
  67. func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) {
  68. var functions = make([]string, 0)
  69. var imports = collection.NewSet()
  70. if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName {
  71. imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb)))
  72. }
  73. imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
  74. imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
  75. buffer, err := templatex.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
  76. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  77. "method": method.Name.Title(),
  78. "request": method.ParameterIn.StarExpression,
  79. "response": method.ParameterOut.StarExpression,
  80. "responseType": method.ParameterOut.Expression,
  81. "hasComment": method.HaveDoc(),
  82. "comment": method.GetDoc(),
  83. })
  84. if err != nil {
  85. return "", nil, err
  86. }
  87. functions = append(functions, buffer.String())
  88. return strings.Join(functions, util.NL), imports.KeysStr(), nil
  89. }