|
@@ -2,6 +2,7 @@ package sqlx
|
|
|
|
|
|
import (
|
|
import (
|
|
"database/sql"
|
|
"database/sql"
|
|
|
|
+ "errors"
|
|
"testing"
|
|
"testing"
|
|
|
|
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
@@ -256,24 +257,6 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
|
|
-func TestUnmarshalRowStructWithTagsPtr(t *testing.T) {
|
|
|
|
- var value = new(struct {
|
|
|
|
- Age *int `db:"age"`
|
|
|
|
- Name string `db:"name"`
|
|
|
|
- })
|
|
|
|
-
|
|
|
|
- runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
|
|
- rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
|
|
|
- mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
|
|
|
-
|
|
|
|
- assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
|
|
|
- return unmarshalRow(value, rows, true)
|
|
|
|
- }, "select name, age from users where user=?", "anyone"))
|
|
|
|
- assert.Equal(t, "liao", value.Name)
|
|
|
|
- assert.Equal(t, 5, *value.Age)
|
|
|
|
- })
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
func TestUnmarshalRowsBool(t *testing.T) {
|
|
func TestUnmarshalRowsBool(t *testing.T) {
|
|
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
var expect = []bool{true, false}
|
|
var expect = []bool{true, false}
|
|
@@ -1001,6 +984,62 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func TestUnmarshalRowError(t *testing.T) {
|
|
|
|
+ tests := []struct {
|
|
|
|
+ name string
|
|
|
|
+ colErr error
|
|
|
|
+ scanErr error
|
|
|
|
+ err error
|
|
|
|
+ next int
|
|
|
|
+ validate func(err error)
|
|
|
|
+ }{
|
|
|
|
+ {
|
|
|
|
+ name: "with error",
|
|
|
|
+ err: errors.New("foo"),
|
|
|
|
+ validate: func(err error) {
|
|
|
|
+ assert.NotNil(t, err)
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "without next",
|
|
|
|
+ validate: func(err error) {
|
|
|
|
+ assert.Equal(t, ErrNotFound, err)
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "with error",
|
|
|
|
+ scanErr: errors.New("foo"),
|
|
|
|
+ next: 1,
|
|
|
|
+ validate: func(err error) {
|
|
|
|
+ assert.Equal(t, ErrNotFound, err)
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for _, test := range tests {
|
|
|
|
+ t.Run(test.name, func(t *testing.T) {
|
|
|
|
+ runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
|
|
+ rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
|
|
|
+ mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
|
|
|
|
+ "anyone").WillReturnRows(rs)
|
|
|
|
+
|
|
|
|
+ var r struct {
|
|
|
|
+ User string `db:"user"`
|
|
|
|
+ Age int `db:"age"`
|
|
|
|
+ }
|
|
|
|
+ test.validate(query(db, func(rows *sql.Rows) error {
|
|
|
|
+ scanner := mockedScanner{
|
|
|
|
+ colErr: test.colErr,
|
|
|
|
+ scanErr: test.scanErr,
|
|
|
|
+ err: test.err,
|
|
|
|
+ }
|
|
|
|
+ return unmarshalRow(&r, &scanner, false)
|
|
|
|
+ }, "select age from users where user=?", "anyone"))
|
|
|
|
+ })
|
|
|
|
+ })
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
|
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
|
logx.Disable()
|
|
logx.Disable()
|
|
|
|
|
|
@@ -1016,3 +1055,30 @@ func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
|
t.Errorf("there were unfulfilled expectations: %s", err)
|
|
t.Errorf("there were unfulfilled expectations: %s", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+type mockedScanner struct {
|
|
|
|
+ colErr error
|
|
|
|
+ scanErr error
|
|
|
|
+ err error
|
|
|
|
+ next int
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m *mockedScanner) Columns() ([]string, error) {
|
|
|
|
+ return nil, m.colErr
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m *mockedScanner) Err() error {
|
|
|
|
+ return m.err
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m *mockedScanner) Next() bool {
|
|
|
|
+ if m.next > 0 {
|
|
|
|
+ m.next--
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+ return false
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m *mockedScanner) Scan(v ...interface{}) error {
|
|
|
|
+ return m.scanErr
|
|
|
|
+}
|