genlogic.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "path"
  6. "strings"
  7. "text/template"
  8. "zero/tools/goctl/api/spec"
  9. "zero/tools/goctl/api/util"
  10. "zero/tools/goctl/vars"
  11. )
// logicTemplate is the text/template source used to render a logic file.
//
// The template references the following keys, all of which must be present
// in the data passed to Execute:
//   - imports:      the body of the import block
//   - logic:        the name of the generated logic struct
//   - function:     the name of the generated method
//   - request:      the method's parameter list (may be empty)
//   - responseType: the method's result list
//   - returnString: the placeholder return statement of the method
//
// The rendered code refers to context, logx and svc, so the imports value
// has to bring those packages into scope.
const logicTemplate = `package logic
import (
{{.imports}}
)
type {{.logic}} struct {
ctx context.Context
logx.Logger
}
func New{{.logic}}(ctx context.Context, svcCtx *svc.ServiceContext) {{.logic}} {
return {{.logic}}{
ctx: ctx,
Logger: logx.WithContext(ctx),
}
// TODO need set model here from svc
}
func (l *{{.logic}}) {{.function}}({{.request}}) {{.responseType}} {
{{.returnString}}
}
`
  31. func genLogic(dir string, api *spec.ApiSpec) error {
  32. for _, g := range api.Service.Groups {
  33. for _, r := range g.Routes {
  34. err := genLogicByRoute(dir, g, r)
  35. if err != nil {
  36. return err
  37. }
  38. }
  39. }
  40. return nil
  41. }
  42. func genLogicByRoute(dir string, group spec.Group, route spec.Route) error {
  43. handler, ok := util.GetAnnotationValue(route.Annotations, "server", "handler")
  44. if !ok {
  45. return fmt.Errorf("missing handler annotation for %q", route.Path)
  46. }
  47. handler = strings.TrimSuffix(handler, "handler")
  48. handler = strings.TrimSuffix(handler, "Handler")
  49. filename := strings.ToLower(handler)
  50. goFile := filename + "logic.go"
  51. fp, created, err := util.MaybeCreateFile(dir, getLogicFolderPath(group, route), goFile)
  52. if err != nil {
  53. return err
  54. }
  55. if !created {
  56. return nil
  57. }
  58. defer fp.Close()
  59. parentPkg, err := getParentPackage(dir)
  60. if err != nil {
  61. return err
  62. }
  63. imports := genLogicImports(route, parentPkg)
  64. responseString := ""
  65. returnString := ""
  66. requestString := ""
  67. if len(route.ResponseType.Name) > 0 {
  68. responseString = "(*types." + strings.Title(route.ResponseType.Name) + ", error)"
  69. returnString = "return nil, nil"
  70. } else {
  71. responseString = "error"
  72. returnString = "return nil"
  73. }
  74. if len(route.RequestType.Name) > 0 {
  75. requestString = "req " + "types." + strings.Title(route.RequestType.Name)
  76. }
  77. t := template.Must(template.New("logicTemplate").Parse(logicTemplate))
  78. buffer := new(bytes.Buffer)
  79. err = t.Execute(fp, map[string]string{
  80. "imports": imports,
  81. "logic": strings.Title(handler) + "Logic",
  82. "function": strings.Title(strings.TrimSuffix(handler, "Handler")),
  83. "responseType": responseString,
  84. "returnString": returnString,
  85. "request": requestString,
  86. })
  87. if err != nil {
  88. return nil
  89. }
  90. formatCode := formatCode(buffer.String())
  91. _, err = fp.WriteString(formatCode)
  92. return err
  93. }
  94. func getLogicFolderPath(group spec.Group, route spec.Route) string {
  95. folder, ok := util.GetAnnotationValue(route.Annotations, "server", folderProperty)
  96. if !ok {
  97. folder, ok = util.GetAnnotationValue(group.Annotations, "server", folderProperty)
  98. if !ok {
  99. return logicDir
  100. }
  101. }
  102. folder = strings.TrimPrefix(folder, "/")
  103. folder = strings.TrimSuffix(folder, "/")
  104. return path.Join(logicDir, folder)
  105. }
  106. func genLogicImports(route spec.Route, parentPkg string) string {
  107. var imports []string
  108. imports = append(imports, `"context"`)
  109. imports = append(imports, "\n")
  110. imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceUrl))
  111. if len(route.ResponseType.Name) > 0 || len(route.RequestType.Name) > 0 {
  112. imports = append(imports, fmt.Sprintf("\"%s\"", path.Join(parentPkg, typesDir)))
  113. }
  114. imports = append(imports, fmt.Sprintf("\"%s\"", path.Join(parentPkg, contextDir)))
  115. return strings.Join(imports, "\n\t")
  116. }