|
@@ -5,6 +5,7 @@
|
|
|
package db
|
|
|
|
|
|
import (
|
|
|
+ "database/sql"
|
|
|
"flag"
|
|
|
"fmt"
|
|
|
"os"
|
|
@@ -12,6 +13,7 @@ import (
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
+ "github.com/stretchr/testify/require"
|
|
|
"gorm.io/gorm"
|
|
|
"gorm.io/gorm/logger"
|
|
|
"gorm.io/gorm/schema"
|
|
@@ -59,16 +61,92 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error {
|
|
|
}
|
|
|
|
|
|
func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
|
|
|
- t.Helper()
|
|
|
+ dbType := os.Getenv("GOGS_DATABASE_TYPE")
|
|
|
+
|
|
|
+ var dbName string
|
|
|
+ var dbOpts conf.DatabaseOpts
|
|
|
+ var cleanup func(db *gorm.DB)
|
|
|
+ switch dbType {
|
|
|
+ case "mysql":
|
|
|
+ dbOpts = conf.DatabaseOpts{
|
|
|
+ Type: "mysql",
|
|
|
+ Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
|
|
|
+ Name: dbName,
|
|
|
+ User: os.Getenv("MYSQL_USER"),
|
|
|
+ Password: os.Getenv("MYSQL_PASSWORD"),
|
|
|
+ }
|
|
|
+
|
|
|
+ dsn, err := newDSN(dbOpts)
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ sqlDB, err := sql.Open("mysql", dsn)
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ // Set up test database
|
|
|
+ dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
|
|
|
+ _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ dbOpts.Name = dbName
|
|
|
+
|
|
|
+ cleanup = func(db *gorm.DB) {
|
|
|
+ db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
|
|
|
+ _ = sqlDB.Close()
|
|
|
+ }
|
|
|
+ case "postgres":
|
|
|
+ dbOpts = conf.DatabaseOpts{
|
|
|
+ Type: "postgres",
|
|
|
+ Host: os.ExpandEnv("$PGHOST:$PGPORT"),
|
|
|
+ Name: dbName,
|
|
|
+ Schema: "public",
|
|
|
+ User: os.Getenv("PGUSER"),
|
|
|
+ Password: os.Getenv("PGPASSWORD"),
|
|
|
+ SSLMode: os.Getenv("PGSSLMODE"),
|
|
|
+ }
|
|
|
+
|
|
|
+ dsn, err := newDSN(dbOpts)
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ sqlDB, err := sql.Open("pgx", dsn)
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ // Set up test database
|
|
|
+ dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
|
|
|
+ _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ dbOpts.Name = dbName
|
|
|
+
|
|
|
+ cleanup = func(db *gorm.DB) {
|
|
|
+ db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
|
|
|
+ _ = sqlDB.Close()
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
|
|
|
+ dbOpts = conf.DatabaseOpts{
|
|
|
+ Type: "sqlite3",
|
|
|
+ Path: dbName,
|
|
|
+ }
|
|
|
+ cleanup = func(db *gorm.DB) {
|
|
|
+ sqlDB, err := db.DB()
|
|
|
+ if err == nil {
|
|
|
+ _ = sqlDB.Close()
|
|
|
+ }
|
|
|
+ _ = os.Remove(dbName)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
|
db, err := openDB(
|
|
|
- conf.DatabaseOpts{
|
|
|
- Type: "sqlite3",
|
|
|
- Path: dbpath,
|
|
|
- },
|
|
|
+ dbOpts,
|
|
|
&gorm.Config{
|
|
|
+ SkipDefaultTransaction: true,
|
|
|
NamingStrategy: schema.NamingStrategy{
|
|
|
SingularTable: true,
|
|
|
},
|
|
@@ -77,27 +155,19 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
|
|
|
},
|
|
|
},
|
|
|
)
|
|
|
- if err != nil {
|
|
|
- t.Fatal(err)
|
|
|
- }
|
|
|
- t.Cleanup(func() {
|
|
|
- sqlDB, err := db.DB()
|
|
|
- if err == nil {
|
|
|
- _ = sqlDB.Close()
|
|
|
- }
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
|
+ t.Cleanup(func() {
|
|
|
if t.Failed() {
|
|
|
- t.Logf("Database %q left intact for inspection", dbpath)
|
|
|
+ t.Logf("Database %q left intact for inspection", dbName)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- _ = os.Remove(dbpath)
|
|
|
+ cleanup(db)
|
|
|
})
|
|
|
|
|
|
err = db.Migrator().AutoMigrate(tables...)
|
|
|
- if err != nil {
|
|
|
- t.Fatal(err)
|
|
|
- }
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
|
return db
|
|
|
}
|