1
0

main_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "database/sql"
  7. "flag"
  8. "fmt"
  9. "os"
  10. "path/filepath"
  11. "testing"
  12. "time"
  13. "github.com/stretchr/testify/require"
  14. "gorm.io/gorm"
  15. "gorm.io/gorm/logger"
  16. "gorm.io/gorm/schema"
  17. log "unknwon.dev/clog/v2"
  18. "gogs.io/gogs/internal/conf"
  19. "gogs.io/gogs/internal/testutil"
  20. )
  21. func TestMain(m *testing.M) {
  22. flag.Parse()
  23. level := logger.Silent
  24. if !testing.Verbose() {
  25. // Remove the primary logger and register a noop logger.
  26. log.Remove(log.DefaultConsoleName)
  27. err := log.New("noop", testutil.InitNoopLogger)
  28. if err != nil {
  29. fmt.Println(err)
  30. os.Exit(1)
  31. }
  32. } else {
  33. level = logger.Info
  34. }
  35. // NOTE: AutoMigrate does not respect logger passed in gorm.Config.
  36. logger.Default = logger.Default.LogMode(level)
  37. os.Exit(m.Run())
  38. }
  39. // clearTables removes all rows from given tables.
  40. func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error {
  41. if t.Failed() {
  42. return nil
  43. }
  44. for _, t := range tables {
  45. err := db.Where("TRUE").Delete(t).Error
  46. if err != nil {
  47. return err
  48. }
  49. }
  50. return nil
  51. }
  52. func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
  53. dbType := os.Getenv("GOGS_DATABASE_TYPE")
  54. var dbName string
  55. var dbOpts conf.DatabaseOpts
  56. var cleanup func(db *gorm.DB)
  57. switch dbType {
  58. case "mysql":
  59. dbOpts = conf.DatabaseOpts{
  60. Type: "mysql",
  61. Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
  62. Name: dbName,
  63. User: os.Getenv("MYSQL_USER"),
  64. Password: os.Getenv("MYSQL_PASSWORD"),
  65. }
  66. dsn, err := newDSN(dbOpts)
  67. require.NoError(t, err)
  68. sqlDB, err := sql.Open("mysql", dsn)
  69. require.NoError(t, err)
  70. // Set up test database
  71. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  72. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
  73. require.NoError(t, err)
  74. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
  75. require.NoError(t, err)
  76. dbOpts.Name = dbName
  77. cleanup = func(db *gorm.DB) {
  78. db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
  79. _ = sqlDB.Close()
  80. }
  81. case "postgres":
  82. dbOpts = conf.DatabaseOpts{
  83. Type: "postgres",
  84. Host: os.ExpandEnv("$PGHOST:$PGPORT"),
  85. Name: dbName,
  86. Schema: "public",
  87. User: os.Getenv("PGUSER"),
  88. Password: os.Getenv("PGPASSWORD"),
  89. SSLMode: os.Getenv("PGSSLMODE"),
  90. }
  91. dsn, err := newDSN(dbOpts)
  92. require.NoError(t, err)
  93. sqlDB, err := sql.Open("pgx", dsn)
  94. require.NoError(t, err)
  95. // Set up test database
  96. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  97. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
  98. require.NoError(t, err)
  99. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
  100. require.NoError(t, err)
  101. dbOpts.Name = dbName
  102. cleanup = func(db *gorm.DB) {
  103. db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
  104. _ = sqlDB.Close()
  105. }
  106. default:
  107. dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
  108. dbOpts = conf.DatabaseOpts{
  109. Type: "sqlite3",
  110. Path: dbName,
  111. }
  112. cleanup = func(db *gorm.DB) {
  113. sqlDB, err := db.DB()
  114. if err == nil {
  115. _ = sqlDB.Close()
  116. }
  117. _ = os.Remove(dbName)
  118. }
  119. }
  120. now := time.Now().UTC().Truncate(time.Second)
  121. db, err := openDB(
  122. dbOpts,
  123. &gorm.Config{
  124. SkipDefaultTransaction: true,
  125. NamingStrategy: schema.NamingStrategy{
  126. SingularTable: true,
  127. },
  128. NowFunc: func() time.Time {
  129. return now
  130. },
  131. },
  132. )
  133. require.NoError(t, err)
  134. t.Cleanup(func() {
  135. if t.Failed() {
  136. t.Logf("Database %q left intact for inspection", dbName)
  137. return
  138. }
  139. cleanup(db)
  140. })
  141. err = db.Migrator().AutoMigrate(tables...)
  142. require.NoError(t, err)
  143. return db
  144. }