genmain.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "strings"
  6. "text/template"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  8. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  9. "github.com/tal-tech/go-zero/tools/goctl/templatex"
  10. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  11. "github.com/tal-tech/go-zero/tools/goctl/vars"
  12. )
// mainTemplate is the built-in text/template source for the generated
// service entry point (package main). genMain passes it to
// templatex.LoadTemplate as the default template text. Template inputs:
//   - importPackages: additional import lines, rendered by genMainImports
//   - serviceName: the service name as written in the API spec, used to
//     build the default config path etc/<serviceName>.yaml
//
// The generated main parses flags, loads the config file into a
// config.Config, builds the service context, creates a rest server
// (stopped via defer), registers the handlers, prints the host/port, and
// calls server.Start.
const mainTemplate = `package main
import (
"flag"
"fmt"
{{.importPackages}}
)
var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
func main() {
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c)
server := rest.MustNewServer(c.RestConf)
defer server.Stop()
handler.RegisterHandlers(server, ctx)
fmt.Printf("Starting server at %s:%d...\n", c.Host, c.Port)
server.Start()
}
`
  32. func genMain(dir string, api *spec.ApiSpec) error {
  33. name := strings.ToLower(api.Service.Name)
  34. if strings.HasSuffix(name, "-api") {
  35. name = strings.ReplaceAll(name, "-api", "")
  36. }
  37. goFile := name + ".go"
  38. fp, created, err := util.MaybeCreateFile(dir, "", goFile)
  39. if err != nil {
  40. return err
  41. }
  42. if !created {
  43. return nil
  44. }
  45. defer fp.Close()
  46. parentPkg, err := getParentPackage(dir)
  47. if err != nil {
  48. return err
  49. }
  50. text, err := templatex.LoadTemplate(category, mainTemplateFile, mainTemplate)
  51. if err != nil {
  52. return err
  53. }
  54. t := template.Must(template.New("mainTemplate").Parse(text))
  55. buffer := new(bytes.Buffer)
  56. err = t.Execute(buffer, map[string]string{
  57. "importPackages": genMainImports(parentPkg),
  58. "serviceName": api.Service.Name,
  59. })
  60. if err != nil {
  61. return nil
  62. }
  63. formatCode := formatCode(buffer.String())
  64. _, err = fp.WriteString(formatCode)
  65. return err
  66. }
  67. func genMainImports(parentPkg string) string {
  68. var imports []string
  69. imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir)))
  70. imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir)))
  71. imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir)))
  72. imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceUrl))
  73. imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl))
  74. return strings.Join(imports, "\n\t")
  75. }