genmain.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "strings"
  6. "text/template"
  7. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  8. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  9. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  10. "github.com/tal-tech/go-zero/tools/goctl/vars"
  11. )
  // mainTemplate is the text/template source that genMain parses and executes
  // to produce the generated service's main file. genMain supplies two string
  // values: importPackages (the import lines built by genMainImports) and
  // serviceName (api.Service.Name, used in the default path of the "f"
  // config-file flag).
  12. const mainTemplate = `package main
  13. import (
  14. "flag"
  15. {{.importPackages}}
  16. )
  17. var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
  18. func main() {
  19. flag.Parse()
  20. var c config.Config
  21. conf.MustLoad(*configFile, &c)
  22. ctx := svc.NewServiceContext(c)
  23. server := rest.MustNewServer(c.RestConf)
  24. defer server.Stop()
  25. handler.RegisterHandlers(server, ctx)
  26. server.Start()
  27. }
  28. `
  29. func genMain(dir string, api *spec.ApiSpec) error {
  30. name := strings.ToLower(api.Service.Name)
  31. if strings.HasSuffix(name, "-api") {
  32. name = strings.ReplaceAll(name, "-api", "")
  33. }
  34. goFile := name + ".go"
  35. fp, created, err := util.MaybeCreateFile(dir, "", goFile)
  36. if err != nil {
  37. return err
  38. }
  39. if !created {
  40. return nil
  41. }
  42. defer fp.Close()
  43. parentPkg, err := getParentPackage(dir)
  44. if err != nil {
  45. return err
  46. }
  47. t := template.Must(template.New("mainTemplate").Parse(mainTemplate))
  48. buffer := new(bytes.Buffer)
  49. err = t.Execute(buffer, map[string]string{
  50. "importPackages": genMainImports(parentPkg),
  51. "serviceName": api.Service.Name,
  52. })
  53. if err != nil {
  54. return nil
  55. }
  56. formatCode := formatCode(buffer.String())
  57. _, err = fp.WriteString(formatCode)
  58. return err
  59. }
  60. func genMainImports(parentPkg string) string {
  61. var imports []string
  62. imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir)))
  63. imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir)))
  64. imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir)))
  65. imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceUrl))
  66. imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl))
  67. return strings.Join(imports, "\n\t")
  68. }