sql.go 865 B

12345678910111213141516171819202122232425262728293031323334353637
  1. package dbtest
  2. import (
  3. "database/sql"
  4. "testing"
  5. "github.com/DATA-DOG/go-sqlmock"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. // RunTest runs a test function with a mock database.
  9. func RunTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
  10. db, mock, err := sqlmock.New()
  11. if err != nil {
  12. t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
  13. }
  14. defer func() {
  15. _ = db.Close()
  16. }()
  17. fn(db, mock)
  18. if err = mock.ExpectationsWereMet(); err != nil {
  19. t.Errorf("there were unfulfilled expectations: %s", err)
  20. }
  21. }
  22. // RunTxTest runs a test function with a mock database in a transaction.
  23. func RunTxTest(t *testing.T, f func(tx *sql.Tx, mock sqlmock.Sqlmock)) {
  24. RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  25. mock.ExpectBegin()
  26. tx, err := db.Begin()
  27. if assert.NoError(t, err) {
  28. f(tx, mock)
  29. }
  30. })
  31. }