genmain.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "sort"
  6. "strings"
  7. "text/template"
  8. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  9. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  10. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  11. "github.com/tal-tech/go-zero/tools/goctl/vars"
  12. )
  // mainTemplate is the text/template source used by genMain to render the
  // generated service's main package. It references two template keys:
  // importPackages (inserted into the import block after "flag") and
  // serviceName (used to build the default path of the -f config flag).
  13. const mainTemplate = `package main
  14. import (
  15. "flag"
  16. {{.importPackages}}
  17. )
  18. var configFile = flag.String("f", "etc/{{.serviceName}}.json", "the config file")
  19. func main() {
  20. flag.Parse()
  21. var c config.Config
  22. conf.MustLoad(*configFile, &c)
  23. ctx := svc.NewServiceContext(c)
  24. server := rest.MustNewServer(c.RestConf)
  25. defer server.Stop()
  26. handler.RegisterHandlers(server, ctx)
  27. server.Start()
  28. }
  29. `
  30. func genMain(dir string, api *spec.ApiSpec) error {
  31. name := strings.ToLower(api.Service.Name)
  32. if strings.HasSuffix(name, "-api") {
  33. name = strings.ReplaceAll(name, "-api", "")
  34. }
  35. goFile := name + ".go"
  36. fp, created, err := util.MaybeCreateFile(dir, "", goFile)
  37. if err != nil {
  38. return err
  39. }
  40. if !created {
  41. return nil
  42. }
  43. defer fp.Close()
  44. parentPkg, err := getParentPackage(dir)
  45. if err != nil {
  46. return err
  47. }
  48. t := template.Must(template.New("mainTemplate").Parse(mainTemplate))
  49. buffer := new(bytes.Buffer)
  50. err = t.Execute(buffer, map[string]string{
  51. "importPackages": genMainImports(parentPkg),
  52. "serviceName": api.Service.Name,
  53. })
  54. if err != nil {
  55. return nil
  56. }
  57. formatCode := formatCode(buffer.String())
  58. _, err = fp.WriteString(formatCode)
  59. return err
  60. }
  61. func genMainImports(parentPkg string) string {
  62. imports := []string{
  63. fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceUrl),
  64. fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl),
  65. }
  66. imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir)))
  67. imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir)))
  68. imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir)))
  69. sort.Strings(imports)
  70. return strings.Join(imports, "\n\t")
  71. }