util.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. goformat "go/format"
  6. "io"
  7. "path/filepath"
  8. "strings"
  9. "text/template"
  10. "github.com/tal-tech/go-zero/core/collection"
  11. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  12. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  13. "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
  14. "github.com/tal-tech/go-zero/tools/goctl/util/pathx"
  15. )
  16. type fileGenConfig struct {
  17. dir string
  18. subdir string
  19. filename string
  20. templateName string
  21. category string
  22. templateFile string
  23. builtinTemplate string
  24. data interface{}
  25. }
  26. func genFile(c fileGenConfig) error {
  27. fp, created, err := util.MaybeCreateFile(c.dir, c.subdir, c.filename)
  28. if err != nil {
  29. return err
  30. }
  31. if !created {
  32. return nil
  33. }
  34. defer fp.Close()
  35. var text string
  36. if len(c.category) == 0 || len(c.templateFile) == 0 {
  37. text = c.builtinTemplate
  38. } else {
  39. text, err = pathx.LoadTemplate(c.category, c.templateFile, c.builtinTemplate)
  40. if err != nil {
  41. return err
  42. }
  43. }
  44. t := template.Must(template.New(c.templateName).Parse(text))
  45. buffer := new(bytes.Buffer)
  46. err = t.Execute(buffer, c.data)
  47. if err != nil {
  48. return err
  49. }
  50. code := formatCode(buffer.String())
  51. _, err = fp.WriteString(code)
  52. return err
  53. }
  54. func getParentPackage(dir string) (string, error) {
  55. abs, err := filepath.Abs(dir)
  56. if err != nil {
  57. return "", err
  58. }
  59. projectCtx, err := ctx.Prepare(abs)
  60. if err != nil {
  61. return "", err
  62. }
  63. // fix https://github.com/zeromicro/go-zero/issues/1058
  64. wd := projectCtx.WorkDir
  65. d := projectCtx.Dir
  66. same, err := pathx.SameFile(wd, d)
  67. if err != nil {
  68. return "", err
  69. }
  70. trim := strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir)
  71. if same {
  72. trim = strings.TrimPrefix(strings.ToLower(projectCtx.WorkDir), strings.ToLower(projectCtx.Dir))
  73. }
  74. return filepath.ToSlash(filepath.Join(projectCtx.Path, trim)), nil
  75. }
  76. func writeProperty(writer io.Writer, name, tag, comment string, tp spec.Type, indent int) error {
  77. util.WriteIndent(writer, indent)
  78. var err error
  79. if len(comment) > 0 {
  80. comment = strings.TrimPrefix(comment, "//")
  81. comment = "//" + comment
  82. _, err = fmt.Fprintf(writer, "%s %s %s %s\n", strings.Title(name), tp.Name(), tag, comment)
  83. } else {
  84. _, err = fmt.Fprintf(writer, "%s %s %s\n", strings.Title(name), tp.Name(), tag)
  85. }
  86. return err
  87. }
  88. func getAuths(api *spec.ApiSpec) []string {
  89. authNames := collection.NewSet()
  90. for _, g := range api.Service.Groups {
  91. jwt := g.GetAnnotation("jwt")
  92. if len(jwt) > 0 {
  93. authNames.Add(jwt)
  94. }
  95. }
  96. return authNames.KeysStr()
  97. }
  98. func getMiddleware(api *spec.ApiSpec) []string {
  99. result := collection.NewSet()
  100. for _, g := range api.Service.Groups {
  101. middleware := g.GetAnnotation("middleware")
  102. if len(middleware) > 0 {
  103. for _, item := range strings.Split(middleware, ",") {
  104. result.Add(strings.TrimSpace(item))
  105. }
  106. }
  107. }
  108. return result.KeysStr()
  109. }
  110. func formatCode(code string) string {
  111. ret, err := goformat.Source([]byte(code))
  112. if err != nil {
  113. return code
  114. }
  115. return string(ret)
  116. }
  117. func responseGoTypeName(r spec.Route, pkg ...string) string {
  118. if r.ResponseType == nil {
  119. return ""
  120. }
  121. resp := golangExpr(r.ResponseType, pkg...)
  122. switch r.ResponseType.(type) {
  123. case spec.DefineStruct:
  124. if !strings.HasPrefix(resp, "*") {
  125. return "*" + resp
  126. }
  127. }
  128. return resp
  129. }
  130. func requestGoTypeName(r spec.Route, pkg ...string) string {
  131. if r.RequestType == nil {
  132. return ""
  133. }
  134. return golangExpr(r.RequestType, pkg...)
  135. }
  136. func golangExpr(ty spec.Type, pkg ...string) string {
  137. switch v := ty.(type) {
  138. case spec.PrimitiveType:
  139. return v.RawName
  140. case spec.DefineStruct:
  141. if len(pkg) > 1 {
  142. panic("package cannot be more than 1")
  143. }
  144. if len(pkg) == 0 {
  145. return v.RawName
  146. }
  147. return fmt.Sprintf("%s.%s", pkg[0], strings.Title(v.RawName))
  148. case spec.ArrayType:
  149. if len(pkg) > 1 {
  150. panic("package cannot be more than 1")
  151. }
  152. if len(pkg) == 0 {
  153. return v.RawName
  154. }
  155. return fmt.Sprintf("[]%s", golangExpr(v.Value, pkg...))
  156. case spec.MapType:
  157. if len(pkg) > 1 {
  158. panic("package cannot be more than 1")
  159. }
  160. if len(pkg) == 0 {
  161. return v.RawName
  162. }
  163. return fmt.Sprintf("map[%s]%s", v.Key, golangExpr(v.Value, pkg...))
  164. case spec.PointerType:
  165. if len(pkg) > 1 {
  166. panic("package cannot be more than 1")
  167. }
  168. if len(pkg) == 0 {
  169. return v.RawName
  170. }
  171. return fmt.Sprintf("*%s", golangExpr(v.Type, pkg...))
  172. case spec.InterfaceType:
  173. return v.RawName
  174. }
  175. return ""
  176. }