util.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. goformat "go/format"
  6. "io"
  7. "path/filepath"
  8. "strings"
  9. "text/template"
  10. "github.com/tal-tech/go-zero/core/collection"
  11. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  12. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  13. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  14. "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
  15. )
  // fileGenConfig bundles everything genFile needs to render one file from a
  // template.
  type fileGenConfig struct {
  	// dir, subdir and filename are handed to util.MaybeCreateFile to obtain
  	// the output file.
  	dir, subdir, filename string
  	// templateName is the name given to the parsed text/template.
  	templateName string
  	// category and templateFile are passed to ctlutil.LoadTemplate; when
  	// either is empty the builtin template text is used directly.
  	category, templateFile string
  	// builtinTemplate is the template text used when no category/templateFile
  	// is set; it is also passed to ctlutil.LoadTemplate as its third argument.
  	builtinTemplate string
  	// data is the value the template is executed against.
  	data interface{}
  }
  26. func genFile(c fileGenConfig) error {
  27. fp, created, err := util.MaybeCreateFile(c.dir, c.subdir, c.filename)
  28. if err != nil {
  29. return err
  30. }
  31. if !created {
  32. return nil
  33. }
  34. defer fp.Close()
  35. var text string
  36. if len(c.category) == 0 || len(c.templateFile) == 0 {
  37. text = c.builtinTemplate
  38. } else {
  39. text, err = ctlutil.LoadTemplate(c.category, c.templateFile, c.builtinTemplate)
  40. if err != nil {
  41. return err
  42. }
  43. }
  44. t := template.Must(template.New(c.templateName).Parse(text))
  45. buffer := new(bytes.Buffer)
  46. err = t.Execute(buffer, c.data)
  47. if err != nil {
  48. return err
  49. }
  50. code := formatCode(buffer.String())
  51. _, err = fp.WriteString(code)
  52. return err
  53. }
  54. func getParentPackage(dir string) (string, error) {
  55. abs, err := filepath.Abs(dir)
  56. if err != nil {
  57. return "", err
  58. }
  59. projectCtx, err := ctx.Prepare(abs)
  60. if err != nil {
  61. return "", err
  62. }
  63. return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
  64. }
  65. func writeProperty(writer io.Writer, name, tp, tag, comment string, indent int) error {
  66. util.WriteIndent(writer, indent)
  67. var err error
  68. if len(comment) > 0 {
  69. comment = strings.TrimPrefix(comment, "//")
  70. comment = "//" + comment
  71. _, err = fmt.Fprintf(writer, "%s %s %s %s\n", strings.Title(name), tp, tag, comment)
  72. } else {
  73. _, err = fmt.Fprintf(writer, "%s %s %s\n", strings.Title(name), tp, tag)
  74. }
  75. return err
  76. }
  77. func getAuths(api *spec.ApiSpec) []string {
  78. authNames := collection.NewSet()
  79. for _, g := range api.Service.Groups {
  80. if value, ok := util.GetAnnotationValue(g.Annotations, "server", "jwt"); ok {
  81. authNames.Add(value)
  82. }
  83. if value, ok := util.GetAnnotationValue(g.Annotations, "server", "signature"); ok {
  84. authNames.Add(value)
  85. }
  86. }
  87. return authNames.KeysStr()
  88. }
  89. func getMiddleware(api *spec.ApiSpec) []string {
  90. result := collection.NewSet()
  91. for _, g := range api.Service.Groups {
  92. if value, ok := util.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
  93. for _, item := range strings.Split(value, ",") {
  94. result.Add(strings.TrimSpace(item))
  95. }
  96. }
  97. }
  98. return result.KeysStr()
  99. }
  // formatCode runs code through go/format and returns the formatted source.
  // If formatting fails, the input is returned unchanged and the error is
  // discarded.
  func formatCode(code string) string {
  	if formatted, err := goformat.Source([]byte(code)); err == nil {
  		return string(formatted)
  	}
  	return code
  }