genlogic.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package gogen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/util"
  9. )
  10. const (
  11. logicTemplate = `package logic
  12. import (
  13. "context"
  14. {{.imports}}
  15. "github.com/tal-tech/go-zero/core/logx"
  16. )
  17. type {{.logicName}} struct {
  18. ctx context.Context
  19. logx.Logger
  20. }
  21. func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
  22. return &{{.logicName}}{
  23. ctx: ctx,
  24. Logger: logx.WithContext(ctx),
  25. }
  26. }
  27. {{.functions}}
  28. `
  29. logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
  30. func (l *{{.logicName}}) {{.method}} (in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
  31. var resp {{.package}}.{{.response}}
  32. // todo: add your logic here and delete this line
  33. return &resp,nil
  34. }
  35. `
  36. )
  37. func (g *defaultRpcGenerator) genLogic() error {
  38. logicPath := g.dirM[dirLogic]
  39. protoPkg := g.ast.Package
  40. service := g.ast.Service
  41. for _, item := range service {
  42. for _, method := range item.Funcs {
  43. logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
  44. filename := filepath.Join(logicPath, logicName)
  45. functions, err := genLogicFunction(protoPkg, method)
  46. if err != nil {
  47. return err
  48. }
  49. imports := collection.NewSet()
  50. pbImport := fmt.Sprintf(`%v "%v"`, protoPkg, g.mustGetPackage(dirPb))
  51. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  52. imports.AddStr(pbImport, svcImport)
  53. err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
  54. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  55. "functions": functions,
  56. "imports": strings.Join(imports.KeysStr(), "\n"),
  57. }, filename, false)
  58. if err != nil {
  59. return err
  60. }
  61. }
  62. }
  63. return nil
  64. }
  65. func genLogicFunction(packageName string, method *parser.Func) (string, error) {
  66. var functions = make([]string, 0)
  67. buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
  68. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  69. "method": method.Name.Title(),
  70. "package": packageName,
  71. "request": method.InType,
  72. "response": method.OutType,
  73. "hasComment": len(method.Document) > 0,
  74. "comment": strings.Join(method.Document, "\n"),
  75. })
  76. if err != nil {
  77. return "", err
  78. }
  79. functions = append(functions, buffer.String())
  80. return strings.Join(functions, "\n"), nil
  81. }