genmain.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "path"
  6. "sort"
  7. "strings"
  8. "text/template"
  9. "zero/tools/goctl/api/spec"
  10. "zero/tools/goctl/api/util"
  11. )
// mainTemplate is the text/template source for the generated service entry
// point (the <name>.go file written by genMain). genMain executes it with a
// map[string]string holding two keys:
//   - importPackages: import paths already quoted and joined with "\n\t"
//     (produced by genMainImports), spliced into the import block below.
//   - serviceName: used to build the default -f flag value
//     "etc/<serviceName>.json". genMain passes the raw api.Service.Name here,
//     not the lower-cased, "-api"-trimmed name it uses for the file name.
//
// The identifiers config, conf, rest, svc and handler used in the body are
// expected to be resolved by importPackages. genMainImports supplies
// zero/core/conf, zero/rest and the config, handler and context directories
// under the parent package. NOTE(review): this assumes the context directory's
// package is named svc — confirm against the generator that writes it.
const mainTemplate = `package main
import (
	"flag"
	{{.importPackages}}
)
var configFile = flag.String("f", "etc/{{.serviceName}}.json", "the config file")
func main() {
	flag.Parse()
	var c config.Config
	conf.MustLoad(*configFile, &c)
	ctx := svc.NewServiceContext(c)
	engine := rest.MustNewEngine(c.RestConf)
	defer engine.Stop()
	handler.RegisterHandlers(engine, ctx)
	engine.Start()
}
`
  29. func genMain(dir string, api *spec.ApiSpec) error {
  30. name := strings.ToLower(api.Service.Name)
  31. if strings.HasSuffix(name, "-api") {
  32. name = strings.ReplaceAll(name, "-api", "")
  33. }
  34. goFile := name + ".go"
  35. fp, created, err := util.MaybeCreateFile(dir, "", goFile)
  36. if err != nil {
  37. return err
  38. }
  39. if !created {
  40. return nil
  41. }
  42. defer fp.Close()
  43. parentPkg, err := getParentPackage(dir)
  44. if err != nil {
  45. return err
  46. }
  47. t := template.Must(template.New("mainTemplate").Parse(mainTemplate))
  48. buffer := new(bytes.Buffer)
  49. err = t.Execute(buffer, map[string]string{
  50. "importPackages": genMainImports(parentPkg),
  51. "serviceName": api.Service.Name,
  52. })
  53. if err != nil {
  54. return nil
  55. }
  56. formatCode := formatCode(buffer.String())
  57. _, err = fp.WriteString(formatCode)
  58. return err
  59. }
  60. func genMainImports(parentPkg string) string {
  61. imports := []string{
  62. `"zero/core/conf"`,
  63. `"zero/rest"`,
  64. }
  65. imports = append(imports, fmt.Sprintf("\"%s\"", path.Join(parentPkg, configDir)))
  66. imports = append(imports, fmt.Sprintf("\"%s\"", path.Join(parentPkg, handlerDir)))
  67. imports = append(imports, fmt.Sprintf("\"%s\"", path.Join(parentPkg, contextDir)))
  68. sort.Strings(imports)
  69. return strings.Join(imports, "\n\t")
  70. }