util.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package gogen
import (
	"bytes"
	"fmt"
	"io"
	"os"
	"sort"
	"strings"
	"text/template"

	"github.com/zeromicro/go-zero/core/collection"
	"github.com/zeromicro/go-zero/tools/goctl/api/spec"
	"github.com/zeromicro/go-zero/tools/goctl/api/util"
	"github.com/zeromicro/go-zero/tools/goctl/pkg/golang"
	"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
  15. type fileGenConfig struct {
  16. dir string
  17. subdir string
  18. filename string
  19. templateName string
  20. category string
  21. templateFile string
  22. builtinTemplate string
  23. data interface{}
  24. }
  25. func genFile(c fileGenConfig) error {
  26. fp, created, err := util.MaybeCreateFile(c.dir, c.subdir, c.filename)
  27. if err != nil {
  28. return err
  29. }
  30. if !created {
  31. return nil
  32. }
  33. defer fp.Close()
  34. var text string
  35. if len(c.category) == 0 || len(c.templateFile) == 0 {
  36. text = c.builtinTemplate
  37. } else {
  38. text, err = pathx.LoadTemplate(c.category, c.templateFile, c.builtinTemplate)
  39. if err != nil {
  40. return err
  41. }
  42. }
  43. t := template.Must(template.New(c.templateName).Parse(text))
  44. buffer := new(bytes.Buffer)
  45. err = t.Execute(buffer, c.data)
  46. if err != nil {
  47. return err
  48. }
  49. code := golang.FormatCode(buffer.String())
  50. _, err = fp.WriteString(code)
  51. return err
  52. }
  53. func writeProperty(writer io.Writer, name, tag, comment string, tp spec.Type, indent int, api *spec.ApiSpec) error {
  54. util.WriteIndent(writer, indent)
  55. var err error
  56. var refPropertyName = tp.Name()
  57. if isCustomType(refPropertyName) {
  58. strs := getRefProperty(api, refPropertyName, name)
  59. _, err = fmt.Fprintf(writer, "%s\n", strs)
  60. } else {
  61. if len(comment) > 0 {
  62. comment = strings.TrimPrefix(comment, "//")
  63. comment = "//" + comment
  64. _, err = fmt.Fprintf(writer, "%s %s %s %s\n", strings.Title(name), tp.Name(), tag, comment)
  65. } else {
  66. _, err = fmt.Fprintf(writer, "%s %s %s\n", strings.Title(name), tp.Name(), tag)
  67. }
  68. }
  69. return err
  70. }
  71. func getAuths(api *spec.ApiSpec) []string {
  72. authNames := collection.NewSet()
  73. for _, g := range api.Service.Groups {
  74. jwt := g.GetAnnotation("jwt")
  75. if len(jwt) > 0 {
  76. authNames.Add(jwt)
  77. }
  78. }
  79. return authNames.KeysStr()
  80. }
  81. func getJwtTrans(api *spec.ApiSpec) []string {
  82. jwtTransList := collection.NewSet()
  83. for _, g := range api.Service.Groups {
  84. jt := g.GetAnnotation(jwtTransKey)
  85. if len(jt) > 0 {
  86. jwtTransList.Add(jt)
  87. }
  88. }
  89. return jwtTransList.KeysStr()
  90. }
  91. func getMiddleware(api *spec.ApiSpec) []string {
  92. result := collection.NewSet()
  93. for _, g := range api.Service.Groups {
  94. middleware := g.GetAnnotation("middleware")
  95. if len(middleware) > 0 {
  96. for _, item := range strings.Split(middleware, ",") {
  97. result.Add(strings.TrimSpace(item))
  98. }
  99. }
  100. }
  101. return result.KeysStr()
  102. }
  103. func responseGoTypeName(r spec.Route, pkg ...string) string {
  104. if r.ResponseType == nil {
  105. return ""
  106. }
  107. resp := golangExpr(r.ResponseType, pkg...)
  108. switch r.ResponseType.(type) {
  109. case spec.DefineStruct:
  110. if !strings.HasPrefix(resp, "*") {
  111. return "*" + resp
  112. }
  113. }
  114. return resp
  115. }
  116. func requestGoTypeName(r spec.Route, pkg ...string) string {
  117. if r.RequestType == nil {
  118. return ""
  119. }
  120. return golangExpr(r.RequestType, pkg...)
  121. }
  122. func golangExpr(ty spec.Type, pkg ...string) string {
  123. switch v := ty.(type) {
  124. case spec.PrimitiveType:
  125. return v.RawName
  126. case spec.DefineStruct:
  127. if len(pkg) > 1 {
  128. panic("package cannot be more than 1")
  129. }
  130. if len(pkg) == 0 {
  131. return v.RawName
  132. }
  133. return fmt.Sprintf("%s.%s", pkg[0], strings.Title(v.RawName))
  134. case spec.ArrayType:
  135. if len(pkg) > 1 {
  136. panic("package cannot be more than 1")
  137. }
  138. if len(pkg) == 0 {
  139. return v.RawName
  140. }
  141. return fmt.Sprintf("[]%s", golangExpr(v.Value, pkg...))
  142. case spec.MapType:
  143. if len(pkg) > 1 {
  144. panic("package cannot be more than 1")
  145. }
  146. if len(pkg) == 0 {
  147. return v.RawName
  148. }
  149. return fmt.Sprintf("map[%s]%s", v.Key, golangExpr(v.Value, pkg...))
  150. case spec.PointerType:
  151. if len(pkg) > 1 {
  152. panic("package cannot be more than 1")
  153. }
  154. if len(pkg) == 0 {
  155. return v.RawName
  156. }
  157. return fmt.Sprintf("*%s", golangExpr(v.Type, pkg...))
  158. case spec.InterfaceType:
  159. return v.RawName
  160. }
  161. return ""
  162. }
  163. func isCustomType(t string) bool {
  164. var builtinType = []string{"string", "bool", "int", "uint", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "float32", "float64", "uintptr", "complex64", "complex128"}
  165. var is bool = true
  166. for _, v := range builtinType {
  167. if t == v {
  168. is = false
  169. break
  170. }
  171. }
  172. return is
  173. }
  174. // Generate nested types recursively
  175. func getRefProperty(api *spec.ApiSpec, refPropertyName string, name string) string {
  176. var str string = ""
  177. for _, t := range api.Types {
  178. if strings.TrimLeft(refPropertyName, "*") == t.Name() {
  179. switch tm := t.(type) {
  180. case spec.DefineStruct:
  181. for _, m := range tm.Members {
  182. if isCustomType(m.Type.Name()) {
  183. // recursive
  184. str += getRefProperty(api, m.Type.Name(), m.Name)
  185. } else {
  186. if len(m.Comment) > 0 {
  187. comment := strings.TrimPrefix(m.Comment, "//")
  188. comment = "//" + comment
  189. str += fmt.Sprintf("%s %s %s %s\n\t", m.Name, m.Type.Name(), m.Tag, comment)
  190. } else {
  191. str += fmt.Sprintf("%s %s %s\n\t", m.Name, m.Type.Name(), m.Tag)
  192. }
  193. }
  194. }
  195. }
  196. }
  197. }
  198. if name == "" {
  199. temp := `${str}`
  200. return os.Expand(temp, func(k string) string {
  201. return str
  202. })
  203. } else {
  204. temp := `${name} struct {
  205. ${str}}`
  206. return os.Expand(temp, func(k string) string {
  207. return map[string]string{
  208. "name": name,
  209. "str": str,
  210. }[k]
  211. })
  212. }
  213. }