genlogic.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package gogen
  2. import (
  3. _ "embed"
  4. "fmt"
  5. "path"
  6. "strconv"
  7. "strings"
  8. "github.com/zeromicro/go-zero/tools/goctl/api/parser/g4/gen/api"
  9. "github.com/zeromicro/go-zero/tools/goctl/api/spec"
  10. "github.com/zeromicro/go-zero/tools/goctl/config"
  11. "github.com/zeromicro/go-zero/tools/goctl/util/format"
  12. "github.com/zeromicro/go-zero/tools/goctl/util/pathx"
  13. "github.com/zeromicro/go-zero/tools/goctl/vars"
  14. )
  15. //go:embed logic.tpl
  16. var logicTemplate string
  17. func genLogic(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
  18. for _, g := range api.Service.Groups {
  19. for _, r := range g.Routes {
  20. err := genLogicByRoute(dir, rootPkg, cfg, g, r)
  21. if err != nil {
  22. return err
  23. }
  24. }
  25. }
  26. return nil
  27. }
  28. func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
  29. logic := getLogicName(route)
  30. goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
  31. if err != nil {
  32. return err
  33. }
  34. imports := genLogicImports(route, rootPkg)
  35. var responseString string
  36. var returnString string
  37. var requestString string
  38. if len(route.ResponseTypeName()) > 0 {
  39. resp := responseGoTypeName(route, typesPacket)
  40. responseString = "(resp " + resp + ", err error)"
  41. returnString = "return"
  42. } else {
  43. responseString = "error"
  44. returnString = "return nil"
  45. }
  46. if len(route.RequestTypeName()) > 0 {
  47. requestString = "req *" + requestGoTypeName(route, typesPacket)
  48. }
  49. subDir := getLogicFolderPath(group, route)
  50. return genFile(fileGenConfig{
  51. dir: dir,
  52. subdir: subDir,
  53. filename: goFile + ".go",
  54. templateName: "logicTemplate",
  55. category: category,
  56. templateFile: logicTemplateFile,
  57. builtinTemplate: logicTemplate,
  58. data: map[string]string{
  59. "pkgName": subDir[strings.LastIndex(subDir, "/")+1:],
  60. "imports": imports,
  61. "logic": strings.Title(logic),
  62. "function": strings.Title(strings.TrimSuffix(logic, "Logic")),
  63. "responseType": responseString,
  64. "returnString": returnString,
  65. "request": requestString,
  66. },
  67. })
  68. }
  69. func getLogicFolderPath(group spec.Group, route spec.Route) string {
  70. folder := route.GetAnnotation(groupProperty)
  71. if len(folder) == 0 {
  72. folder = group.GetAnnotation(groupProperty)
  73. if len(folder) == 0 {
  74. return logicDir
  75. }
  76. }
  77. folder = strings.TrimPrefix(folder, "/")
  78. folder = strings.TrimSuffix(folder, "/")
  79. return path.Join(logicDir, folder)
  80. }
  81. func genLogicImports(route spec.Route, parentPkg string) string {
  82. var imports []string
  83. imports = append(imports, `"context"`+"\n")
  84. imports = append(imports, fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir)))
  85. if shallImportTypesPackage(route) {
  86. imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir)))
  87. }
  88. imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
  89. return strings.Join(imports, "\n\t")
  90. }
  91. func onlyPrimitiveTypes(val string) bool {
  92. fields := strings.FieldsFunc(val, func(r rune) bool {
  93. return r == '[' || r == ']' || r == ' '
  94. })
  95. for _, field := range fields {
  96. if field == "map" {
  97. continue
  98. }
  99. // ignore array dimension number, like [5]int
  100. if _, err := strconv.Atoi(field); err == nil {
  101. continue
  102. }
  103. if !api.IsBasicType(field) {
  104. return false
  105. }
  106. }
  107. return true
  108. }
  109. func shallImportTypesPackage(route spec.Route) bool {
  110. if len(route.RequestTypeName()) > 0 {
  111. return true
  112. }
  113. respTypeName := route.ResponseTypeName()
  114. if len(respTypeName) == 0 {
  115. return false
  116. }
  117. if onlyPrimitiveTypes(respTypeName) {
  118. return false
  119. }
  120. return true
  121. }