main.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. // Copyright 2022 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE and LICENSE.gogs file.
  4. // Copyright 2025 Huan-Gogs Authors. All rights reserved.
  5. // Use of this source code is governed by a MIT-style
  6. // license that can be found in the LICENSE file.
  7. package main
  8. import (
  9. "fmt"
  10. "log"
  11. "os"
  12. "sort"
  13. "strings"
  14. "github.com/olekukonko/tablewriter"
  15. "github.com/pkg/errors"
  16. "gopkg.in/DATA-DOG/go-sqlmock.v2"
  17. "gorm.io/driver/mysql"
  18. "gorm.io/driver/postgres"
  19. "gorm.io/driver/sqlite"
  20. "gorm.io/gorm"
  21. "gorm.io/gorm/clause"
  22. "gorm.io/gorm/schema"
  23. "github.com/SongZihuan/huan-gogs/internal/database"
  24. )
  25. //go:generate go run main.go ../../../docs/dev/database_schema.md
  26. func main() {
  27. w, err := os.Create(os.Args[1])
  28. if err != nil {
  29. log.Fatalf("Failed to create file: %v", err)
  30. }
  31. defer func() { _ = w.Close() }()
  32. conn, _, err := sqlmock.New()
  33. if err != nil {
  34. log.Fatalf("Failed to get mock connection: %v", err)
  35. }
  36. defer func() { _ = conn.Close() }()
  37. dialectors := []gorm.Dialector{
  38. postgres.New(postgres.Config{
  39. Conn: conn,
  40. }),
  41. mysql.New(mysql.Config{
  42. Conn: conn,
  43. SkipInitializeWithVersion: true,
  44. }),
  45. sqlite.Open(""),
  46. }
  47. collected := make([][]*tableInfo, 0, len(dialectors))
  48. for i, dialector := range dialectors {
  49. tableInfos, err := generate(dialector)
  50. if err != nil {
  51. log.Fatalf("Failed to get table info of %d: %v", i, err)
  52. }
  53. collected = append(collected, tableInfos)
  54. }
  55. for i, ti := range collected[0] {
  56. _, _ = w.WriteString(`# Table "` + ti.Name + `"`)
  57. _, _ = w.WriteString("\n\n")
  58. _, _ = w.WriteString("```\n")
  59. table := tablewriter.NewWriter(w)
  60. table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"})
  61. table.SetBorder(false)
  62. for j, f := range ti.Fields {
  63. table.Append([]string{
  64. f.Name, f.Column,
  65. strings.ToUpper(f.Type), // PostgreSQL
  66. strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL
  67. strings.ToUpper(collected[2][i].Fields[j].Type), // SQLite3
  68. })
  69. }
  70. table.Render()
  71. _, _ = w.WriteString("\n")
  72. _, _ = w.WriteString("Primary keys: ")
  73. _, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", "))
  74. _, _ = w.WriteString("\n")
  75. if len(ti.Indexes) > 0 {
  76. _, _ = w.WriteString("Indexes: \n")
  77. for _, index := range ti.Indexes {
  78. _, _ = w.WriteString(fmt.Sprintf("\t%q", index.Name))
  79. if index.Class != "" {
  80. _, _ = w.WriteString(fmt.Sprintf(" %s", index.Class))
  81. }
  82. if index.Type != "" {
  83. _, _ = w.WriteString(fmt.Sprintf(", %s", index.Type))
  84. }
  85. if len(index.Fields) > 0 {
  86. fields := make([]string, len(index.Fields))
  87. for i := range index.Fields {
  88. fields[i] = index.Fields[i].DBName
  89. }
  90. _, _ = w.WriteString(fmt.Sprintf(" (%s)", strings.Join(fields, ", ")))
  91. }
  92. _, _ = w.WriteString("\n")
  93. }
  94. }
  95. _, _ = w.WriteString("```\n\n")
  96. }
  97. }
  // tableField is one column of a table as described for a single dialect:
// the Go struct field name, the database column name, and the column type
// reported by that dialect (suffixed with " UNIQUE" when the field is tagged
// unique).
type tableField struct {
	Name, Column, Type string
}
  103. type tableInfo struct {
  104. Name string
  105. Fields []*tableField
  106. PrimaryKeys []string
  107. Indexes []schema.Index
  108. }
  109. // This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable.
  110. func generate(dialector gorm.Dialector) ([]*tableInfo, error) {
  111. conn, err := gorm.Open(dialector,
  112. &gorm.Config{
  113. SkipDefaultTransaction: true,
  114. NamingStrategy: schema.NamingStrategy{
  115. SingularTable: true,
  116. },
  117. DryRun: true,
  118. DisableAutomaticPing: true,
  119. },
  120. )
  121. if err != nil {
  122. return nil, errors.Wrap(err, "open database")
  123. }
  124. m := conn.Migrator().(interface {
  125. RunWithValue(value any, fc func(*gorm.Statement) error) error
  126. FullDataTypeOf(*schema.Field) clause.Expr
  127. })
  128. tableInfos := make([]*tableInfo, 0, len(database.Tables))
  129. for _, table := range database.Tables {
  130. err = m.RunWithValue(table, func(stmt *gorm.Statement) error {
  131. fields := make([]*tableField, 0, len(stmt.Schema.DBNames))
  132. for _, field := range stmt.Schema.Fields {
  133. if field.DBName == "" {
  134. continue
  135. }
  136. tags := make([]string, 0)
  137. for tag := range field.TagSettings {
  138. if tag == "UNIQUE" {
  139. tags = append(tags, tag)
  140. }
  141. }
  142. typeSuffix := ""
  143. if len(tags) > 0 {
  144. typeSuffix = " " + strings.Join(tags, " ")
  145. }
  146. fields = append(fields, &tableField{
  147. Name: field.Name,
  148. Column: field.DBName,
  149. Type: m.FullDataTypeOf(field).SQL + typeSuffix,
  150. })
  151. }
  152. primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields))
  153. if len(stmt.Schema.PrimaryFields) > 0 {
  154. for _, field := range stmt.Schema.PrimaryFields {
  155. primaryKeys = append(primaryKeys, field.DBName)
  156. }
  157. }
  158. var indexes []schema.Index
  159. for _, index := range stmt.Schema.ParseIndexes() {
  160. indexes = append(indexes, index)
  161. }
  162. sort.Slice(indexes, func(i, j int) bool {
  163. return indexes[i].Name < indexes[j].Name
  164. })
  165. tableInfos = append(tableInfos, &tableInfo{
  166. Name: stmt.Table,
  167. Fields: fields,
  168. PrimaryKeys: primaryKeys,
  169. Indexes: indexes,
  170. })
  171. return nil
  172. })
  173. if err != nil {
  174. return nil, errors.Wrap(err, "gather table information")
  175. }
  176. }
  177. return tableInfos, nil
  178. }