genlogic.go 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/util"
  9. )
  10. const (
  11. logicTemplate = `package logic
  12. import (
  13. "context"
  14. {{.imports}}
  15. "github.com/tal-tech/go-zero/core/logx"
  16. )
  17. type {{.logicName}} struct {
  18. ctx context.Context
  19. logx.Logger
  20. }
  21. func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
  22. return &{{.logicName}}{
  23. ctx: ctx,
  24. Logger: logx.WithContext(ctx),
  25. }
  26. }
  27. {{.functions}}
  28. `
  29. logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
  30. func (l *{{.logicName}}) {{.method}} (in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
  31. // todo: add your logic here and delete this line
  32. return &{{.package}}.{{.response}}{}, nil
  33. }
  34. `
  35. )
  36. func (g *defaultRpcGenerator) genLogic() error {
  37. logicPath := g.dirM[dirLogic]
  38. protoPkg := g.ast.Package
  39. service := g.ast.Service
  40. for _, item := range service {
  41. for _, method := range item.Funcs {
  42. logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
  43. filename := filepath.Join(logicPath, logicName)
  44. functions, err := genLogicFunction(protoPkg, method)
  45. if err != nil {
  46. return err
  47. }
  48. imports := collection.NewSet()
  49. pbImport := fmt.Sprintf(`%v "%v"`, protoPkg, g.mustGetPackage(dirPb))
  50. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  51. imports.AddStr(pbImport, svcImport)
  52. err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
  53. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  54. "functions": functions,
  55. "imports": strings.Join(imports.KeysStr(), "\n"),
  56. }, filename, false)
  57. if err != nil {
  58. return err
  59. }
  60. }
  61. }
  62. return nil
  63. }
  64. func genLogicFunction(packageName string, method *parser.Func) (string, error) {
  65. var functions = make([]string, 0)
  66. buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
  67. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  68. "method": method.Name.Title(),
  69. "package": packageName,
  70. "request": method.InType,
  71. "response": method.OutType,
  72. "hasComment": len(method.Document) > 0,
  73. "comment": strings.Join(method.Document, "\n"),
  74. })
  75. if err != nil {
  76. return "", err
  77. }
  78. functions = append(functions, buffer.String())
  79. return strings.Join(functions, "\n"), nil
  80. }