gensvc.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "text/template"
  6. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  8. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  9. "github.com/tal-tech/go-zero/tools/goctl/vars"
  10. )
  11. const (
  12. contextFilename = "servicecontext.go"
  13. contextTemplate = `package svc
  14. import (
  15. {{.configImport}}
  16. )
  17. type ServiceContext struct {
  18. Config {{.config}}
  19. {{.middleware}}
  20. }
  21. func NewServiceContext(c {{.config}}) *ServiceContext {
  22. return &ServiceContext{Config: c}
  23. }
  24. `
  25. )
  26. func genServiceContext(dir string, api *spec.ApiSpec) error {
  27. fp, created, err := util.MaybeCreateFile(dir, contextDir, contextFilename)
  28. if err != nil {
  29. return err
  30. }
  31. if !created {
  32. return nil
  33. }
  34. defer fp.Close()
  35. var authNames = getAuths(api)
  36. var auths []string
  37. for _, item := range authNames {
  38. auths = append(auths, fmt.Sprintf("%s config.AuthConfig", item))
  39. }
  40. parentPkg, err := getParentPackage(dir)
  41. if err != nil {
  42. return err
  43. }
  44. text, err := ctlutil.LoadTemplate(category, contextTemplateFile, contextTemplate)
  45. if err != nil {
  46. return err
  47. }
  48. var middlewareStr string
  49. for _, item := range getMiddleware(api) {
  50. middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
  51. }
  52. var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
  53. if len(middlewareStr) > 0 {
  54. configImport += fmt.Sprintf("\n\"%s/rest\"", vars.ProjectOpenSourceUrl)
  55. }
  56. t := template.Must(template.New("contextTemplate").Parse(text))
  57. buffer := new(bytes.Buffer)
  58. err = t.Execute(buffer, map[string]string{
  59. "configImport": configImport,
  60. "config": "config.Config",
  61. "middleware": middlewareStr,
  62. })
  63. if err != nil {
  64. return nil
  65. }
  66. formatCode := formatCode(buffer.String())
  67. _, err = fp.WriteString(formatCode)
  68. return err
  69. }