gensvc.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "strings"
  6. "text/template"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  8. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  9. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  10. "github.com/tal-tech/go-zero/tools/goctl/vars"
  11. )
const (
	// contextFilename is the name of the service-context file that
	// genServiceContext creates inside contextDir.
	contextFilename = "servicecontext.go"
	// contextTemplate is the built-in text/template source for the service
	// context file. genServiceContext passes it to ctlutil.LoadTemplate
	// (presumably as the fallback when no custom template exists — verify)
	// and executes the result with these string keys:
	//   configImport         - import lines (config package, plus middleware
	//                          and rest packages when middlewares exist)
	//   config               - the config type name ("config.Config")
	//   middleware           - one "Name rest.Middleware" field per line
	//   middlewareAssignment - one "Name: middleware.NewXxx().Handle," entry per line
	contextTemplate = `package svc
import (
{{.configImport}}
)
type ServiceContext struct {
Config {{.config}}
{{.middleware}}
}
func NewServiceContext(c {{.config}}) *ServiceContext {
return &ServiceContext{
Config: c,
{{.middlewareAssignment}}
}
}
`
)
  30. func genServiceContext(dir string, api *spec.ApiSpec) error {
  31. fp, created, err := util.MaybeCreateFile(dir, contextDir, contextFilename)
  32. if err != nil {
  33. return err
  34. }
  35. if !created {
  36. return nil
  37. }
  38. defer fp.Close()
  39. var authNames = getAuths(api)
  40. var auths []string
  41. for _, item := range authNames {
  42. auths = append(auths, fmt.Sprintf("%s config.AuthConfig", item))
  43. }
  44. parentPkg, err := getParentPackage(dir)
  45. if err != nil {
  46. return err
  47. }
  48. text, err := ctlutil.LoadTemplate(category, contextTemplateFile, contextTemplate)
  49. if err != nil {
  50. return err
  51. }
  52. var middlewareStr string
  53. var middlewareAssignment string
  54. var middlewares = getMiddleware(api)
  55. err = genMiddleware(dir, middlewares)
  56. if err != nil {
  57. return err
  58. }
  59. for _, item := range middlewares {
  60. middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
  61. name := strings.TrimSuffix(item, "Middleware") + "Middleware"
  62. middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle"))
  63. }
  64. var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
  65. if len(middlewareStr) > 0 {
  66. configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
  67. configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl)
  68. }
  69. t := template.Must(template.New("contextTemplate").Parse(text))
  70. buffer := new(bytes.Buffer)
  71. err = t.Execute(buffer, map[string]string{
  72. "configImport": configImport,
  73. "config": "config.Config",
  74. "middleware": middlewareStr,
  75. "middlewareAssignment": middlewareAssignment,
  76. })
  77. if err != nil {
  78. return err
  79. }
  80. formatCode := formatCode(buffer.String())
  81. _, err = fp.WriteString(formatCode)
  82. return err
  83. }