gensvc.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "text/template"
  6. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  8. "github.com/tal-tech/go-zero/tools/goctl/templatex"
  9. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  10. "github.com/tal-tech/go-zero/tools/goctl/vars"
  11. )
const (
	// contextFilename is the file name genServiceContext passes to
	// util.MaybeCreateFile for the generated service context source file.
	contextFilename = "servicecontext.go"
	// contextTemplate is the template text genServiceContext hands to
	// templatex.LoadTemplate (presumably used as the built-in fallback when no
	// customized template exists — confirm against templatex). When the
	// template is executed, genServiceContext supplies three string values:
	// .configImport (the quoted import lines), .config (the config type name)
	// and .middleware (the middleware field declarations, possibly empty).
	contextTemplate = `package svc
import (
{{.configImport}}
)
type ServiceContext struct {
Config {{.config}}
{{.middleware}}
}
func NewServiceContext(c {{.config}}) *ServiceContext {
return &ServiceContext{Config: c}
}
`
)
  27. func genServiceContext(dir string, api *spec.ApiSpec) error {
  28. fp, created, err := util.MaybeCreateFile(dir, contextDir, contextFilename)
  29. if err != nil {
  30. return err
  31. }
  32. if !created {
  33. return nil
  34. }
  35. defer fp.Close()
  36. var authNames = getAuths(api)
  37. var auths []string
  38. for _, item := range authNames {
  39. auths = append(auths, fmt.Sprintf("%s config.AuthConfig", item))
  40. }
  41. parentPkg, err := getParentPackage(dir)
  42. if err != nil {
  43. return err
  44. }
  45. text, err := templatex.LoadTemplate(category, contextTemplateFile, contextTemplate)
  46. if err != nil {
  47. return err
  48. }
  49. var middlewareStr string
  50. for _, item := range getMiddleware(api) {
  51. middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
  52. }
  53. var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
  54. if len(middlewareStr) > 0 {
  55. configImport += fmt.Sprintf("\n\"%s/rest\"", vars.ProjectOpenSourceUrl)
  56. }
  57. t := template.Must(template.New("contextTemplate").Parse(text))
  58. buffer := new(bytes.Buffer)
  59. err = t.Execute(buffer, map[string]string{
  60. "configImport": configImport,
  61. "config": "config.Config",
  62. "middleware": middlewareStr,
  63. })
  64. if err != nil {
  65. return nil
  66. }
  67. formatCode := formatCode(buffer.String())
  68. _, err = fp.WriteString(formatCode)
  69. return err
  70. }