model.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package gen
  2. import (
  3. "bytes"
  4. "go/format"
  5. "strings"
  6. "text/template"
  7. "github.com/tal-tech/go-zero/core/logx"
  8. sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
  9. )
  10. func GenModel(table *InnerTable) (string, error) {
  11. t, err := template.New("model").Parse(sqltemplate.Model)
  12. if err != nil {
  13. return "", nil
  14. }
  15. modelBuffer := new(bytes.Buffer)
  16. importsCode, err := genImports(table)
  17. if err != nil {
  18. return "", err
  19. }
  20. varsCode, err := genVars(table)
  21. if err != nil {
  22. return "", err
  23. }
  24. typesCode, err := genTypes(table)
  25. if err != nil {
  26. return "", err
  27. }
  28. newCode, err := genNew(table)
  29. if err != nil {
  30. return "", err
  31. }
  32. insertCode, err := genInsert(table)
  33. if err != nil {
  34. return "", err
  35. }
  36. var findCode = make([]string, 0)
  37. findOneCode, err := genFindOne(table)
  38. if err != nil {
  39. return "", err
  40. }
  41. findOneByFieldCode, err := genFineOneByField(table)
  42. if err != nil {
  43. return "", err
  44. }
  45. findAllCode, err := genFindAllByField(table)
  46. if err != nil {
  47. return "", err
  48. }
  49. findLimitCode, err := genFindLimitByField(table)
  50. if err != nil {
  51. return "", err
  52. }
  53. findCode = append(findCode, findOneCode, findOneByFieldCode, findAllCode, findLimitCode)
  54. updateCode, err := genUpdate(table)
  55. if err != nil {
  56. return "", err
  57. }
  58. deleteCode, err := genDelete(table)
  59. if err != nil {
  60. return "", err
  61. }
  62. err = t.Execute(modelBuffer, map[string]interface{}{
  63. "imports": importsCode,
  64. "vars": varsCode,
  65. "types": typesCode,
  66. "new": newCode,
  67. "insert": insertCode,
  68. "find": strings.Join(findCode, "\r\n"),
  69. "update": updateCode,
  70. "delete": deleteCode,
  71. })
  72. if err != nil {
  73. return "", err
  74. }
  75. result := modelBuffer.String()
  76. bts, err := format.Source([]byte(result))
  77. if err != nil {
  78. logx.Errorf("%+v", err)
  79. return "", err
  80. }
  81. return string(bts), nil
  82. }