genlogic.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package gogen
  2. import (
  3. "fmt"
  4. "path"
  5. "strconv"
  6. "strings"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
  8. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  9. "github.com/tal-tech/go-zero/tools/goctl/config"
  10. "github.com/tal-tech/go-zero/tools/goctl/util/format"
  11. "github.com/tal-tech/go-zero/tools/goctl/util/pathx"
  12. "github.com/tal-tech/go-zero/tools/goctl/vars"
  13. )
  14. const logicTemplate = `package {{.pkgName}}
  15. import (
  16. {{.imports}}
  17. )
  18. type {{.logic}} struct {
  19. logx.Logger
  20. ctx context.Context
  21. svcCtx *svc.ServiceContext
  22. }
  23. func New{{.logic}}(ctx context.Context, svcCtx *svc.ServiceContext) {{.logic}} {
  24. return {{.logic}}{
  25. Logger: logx.WithContext(ctx),
  26. ctx: ctx,
  27. svcCtx: svcCtx,
  28. }
  29. }
  30. func (l *{{.logic}}) {{.function}}({{.request}}) {{.responseType}} {
  31. // todo: add your logic here and delete this line
  32. {{.returnString}}
  33. }
  34. `
  35. func genLogic(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
  36. for _, g := range api.Service.Groups {
  37. for _, r := range g.Routes {
  38. err := genLogicByRoute(dir, rootPkg, cfg, g, r)
  39. if err != nil {
  40. return err
  41. }
  42. }
  43. }
  44. return nil
  45. }
  46. func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
  47. logic := getLogicName(route)
  48. goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
  49. if err != nil {
  50. return err
  51. }
  52. imports := genLogicImports(route, rootPkg)
  53. var responseString string
  54. var returnString string
  55. var requestString string
  56. if len(route.ResponseTypeName()) > 0 {
  57. resp := responseGoTypeName(route, typesPacket)
  58. responseString = "(resp " + resp + ", err error)"
  59. returnString = "return"
  60. } else {
  61. responseString = "error"
  62. returnString = "return nil"
  63. }
  64. if len(route.RequestTypeName()) > 0 {
  65. requestString = "req " + requestGoTypeName(route, typesPacket)
  66. }
  67. subDir := getLogicFolderPath(group, route)
  68. return genFile(fileGenConfig{
  69. dir: dir,
  70. subdir: subDir,
  71. filename: goFile + ".go",
  72. templateName: "logicTemplate",
  73. category: category,
  74. templateFile: logicTemplateFile,
  75. builtinTemplate: logicTemplate,
  76. data: map[string]string{
  77. "pkgName": subDir[strings.LastIndex(subDir, "/")+1:],
  78. "imports": imports,
  79. "logic": strings.Title(logic),
  80. "function": strings.Title(strings.TrimSuffix(logic, "Logic")),
  81. "responseType": responseString,
  82. "returnString": returnString,
  83. "request": requestString,
  84. },
  85. })
  86. }
  87. func getLogicFolderPath(group spec.Group, route spec.Route) string {
  88. folder := route.GetAnnotation(groupProperty)
  89. if len(folder) == 0 {
  90. folder = group.GetAnnotation(groupProperty)
  91. if len(folder) == 0 {
  92. return logicDir
  93. }
  94. }
  95. folder = strings.TrimPrefix(folder, "/")
  96. folder = strings.TrimSuffix(folder, "/")
  97. return path.Join(logicDir, folder)
  98. }
  99. func genLogicImports(route spec.Route, parentPkg string) string {
  100. var imports []string
  101. imports = append(imports, `"context"`+"\n")
  102. imports = append(imports, fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir)))
  103. if shallImportTypesPackage(route) {
  104. imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir)))
  105. }
  106. imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
  107. return strings.Join(imports, "\n\t")
  108. }
  109. func onlyPrimitiveTypes(val string) bool {
  110. fields := strings.FieldsFunc(val, func(r rune) bool {
  111. return r == '[' || r == ']' || r == ' '
  112. })
  113. for _, field := range fields {
  114. if field == "map" {
  115. continue
  116. }
  117. // ignore array dimension number, like [5]int
  118. if _, err := strconv.Atoi(field); err == nil {
  119. continue
  120. }
  121. if !api.IsBasicType(field) {
  122. return false
  123. }
  124. }
  125. return true
  126. }
  127. func shallImportTypesPackage(route spec.Route) bool {
  128. if len(route.RequestTypeName()) > 0 {
  129. return true
  130. }
  131. respTypeName := route.ResponseTypeName()
  132. if len(respTypeName) == 0 {
  133. return false
  134. }
  135. if onlyPrimitiveTypes(respTypeName) {
  136. return false
  137. }
  138. return true
  139. }