// gensvc.go (2.4 KB)

  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "strings"
  6. "text/template"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  8. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  9. "github.com/tal-tech/go-zero/tools/goctl/config"
  10. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  11. "github.com/tal-tech/go-zero/tools/goctl/util/format"
  12. "github.com/tal-tech/go-zero/tools/goctl/vars"
  13. )
  // Constants used by genServiceContext when generating the service-context
  // source file.
  14. const (
  // contextFilename is the base name handed to format.FileNamingFormat to
  // derive the generated file's name; genServiceContext appends ".go" to it.
  15. contextFilename = "service_context"
  // contextTemplate is the built-in template text passed to
  // ctlutil.LoadTemplate. genServiceContext executes the resulting template
  // with four string values: configImport, config, middleware and
  // middlewareAssignment.
  16. contextTemplate = `package svc
  17. import (
  18. {{.configImport}}
  19. )
  20. type ServiceContext struct {
  21. Config {{.config}}
  22. {{.middleware}}
  23. }
  24. func NewServiceContext(c {{.config}}) *ServiceContext {
  25. return &ServiceContext{
  26. Config: c,
  27. {{.middlewareAssignment}}
  28. }
  29. }
  30. `
  31. )
  32. func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error {
  33. filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
  34. if err != nil {
  35. return err
  36. }
  37. fp, created, err := util.MaybeCreateFile(dir, contextDir, filename+".go")
  38. if err != nil {
  39. return err
  40. }
  41. if !created {
  42. return nil
  43. }
  44. defer fp.Close()
  45. var authNames = getAuths(api)
  46. var auths []string
  47. for _, item := range authNames {
  48. auths = append(auths, fmt.Sprintf("%s config.AuthConfig", item))
  49. }
  50. parentPkg, err := getParentPackage(dir)
  51. if err != nil {
  52. return err
  53. }
  54. text, err := ctlutil.LoadTemplate(category, contextTemplateFile, contextTemplate)
  55. if err != nil {
  56. return err
  57. }
  58. var middlewareStr string
  59. var middlewareAssignment string
  60. var middlewares = getMiddleware(api)
  61. for _, item := range middlewares {
  62. middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
  63. name := strings.TrimSuffix(item, "Middleware") + "Middleware"
  64. middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle"))
  65. }
  66. var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
  67. if len(middlewareStr) > 0 {
  68. configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
  69. configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl)
  70. }
  71. t := template.Must(template.New("contextTemplate").Parse(text))
  72. buffer := new(bytes.Buffer)
  73. err = t.Execute(buffer, map[string]string{
  74. "configImport": configImport,
  75. "config": "config.Config",
  76. "middleware": middlewareStr,
  77. "middlewareAssignment": middlewareAssignment,
  78. })
  79. if err != nil {
  80. return err
  81. }
  82. formatCode := formatCode(buffer.String())
  83. _, err = fp.WriteString(formatCode)
  84. return err
  85. }