gensvc.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "text/template"
  6. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  8. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  9. "github.com/tal-tech/go-zero/tools/goctl/vars"
  10. )
  11. const (
  12. contextFilename = "servicecontext.go"
  13. contextTemplate = `package svc
  14. import (
  15. {{.configImport}}
  16. )
  17. type ServiceContext struct {
  18. Config {{.config}}
  19. {{.middleware}}
  20. }
  21. func NewServiceContext(c {{.config}}) *ServiceContext {
  22. return &ServiceContext{
  23. Config: c,
  24. {{.middlewareAssignment}}
  25. }
  26. }
  27. {{.middlewareImplement}}
  28. `
  29. middlewareImplementCode = `func %s(next http.HandlerFunc) http.HandlerFunc {
  30. return func(w http.ResponseWriter, r *http.Request) {
  31. // TODO generate middleware implement function, delete after code implementation
  32. }
  33. }
  34. `
  35. )
  36. func genServiceContext(dir string, api *spec.ApiSpec) error {
  37. fp, created, err := util.MaybeCreateFile(dir, contextDir, contextFilename)
  38. if err != nil {
  39. return err
  40. }
  41. if !created {
  42. return nil
  43. }
  44. defer fp.Close()
  45. var authNames = getAuths(api)
  46. var auths []string
  47. for _, item := range authNames {
  48. auths = append(auths, fmt.Sprintf("%s config.AuthConfig", item))
  49. }
  50. parentPkg, err := getParentPackage(dir)
  51. if err != nil {
  52. return err
  53. }
  54. text, err := ctlutil.LoadTemplate(category, contextTemplateFile, contextTemplate)
  55. if err != nil {
  56. return err
  57. }
  58. var middlewareStr string
  59. var middlewareAssignment string
  60. var middlewareImplement string
  61. for _, item := range getMiddleware(api) {
  62. middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
  63. middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, item)
  64. middlewareImplement += fmt.Sprintf(middlewareImplementCode, item)
  65. }
  66. var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
  67. if len(middlewareStr) > 0 {
  68. configImport += "\n\t\"net/http\""
  69. configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl)
  70. }
  71. t := template.Must(template.New("contextTemplate").Parse(text))
  72. buffer := new(bytes.Buffer)
  73. err = t.Execute(buffer, map[string]string{
  74. "configImport": configImport,
  75. "config": "config.Config",
  76. "middleware": middlewareStr,
  77. "middlewareAssignment": middlewareAssignment,
  78. "middlewareImplement": middlewareImplement,
  79. })
  80. if err != nil {
  81. return nil
  82. }
  83. formatCode := formatCode(buffer.String())
  84. _, err = fp.WriteString(formatCode)
  85. return err
  86. }