|
@@ -0,0 +1,245 @@
|
|
|
+package sqlx
|
|
|
+
|
|
|
+import (
|
|
|
+ "database/sql"
|
|
|
+ "errors"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/stretchr/testify/assert"
|
|
|
+)
|
|
|
+
|
|
|
+var errMockedPlaceholder = errors.New("placeholder")
|
|
|
+
|
|
|
+func TestStmt_exec(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ args []interface{}
|
|
|
+ delay bool
|
|
|
+ formatError bool
|
|
|
+ hasError bool
|
|
|
+ err error
|
|
|
+ lastInsertId int64
|
|
|
+ rowsAffected int64
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "normal",
|
|
|
+ args: []interface{}{1},
|
|
|
+ lastInsertId: 1,
|
|
|
+ rowsAffected: 2,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "wrong format",
|
|
|
+ args: []interface{}{1, 2},
|
|
|
+ formatError: true,
|
|
|
+ hasError: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "exec error",
|
|
|
+ args: []interface{}{1},
|
|
|
+ hasError: true,
|
|
|
+ err: errors.New("exec"),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "slowcall",
|
|
|
+ args: []interface{}{1},
|
|
|
+ delay: true,
|
|
|
+ lastInsertId: 1,
|
|
|
+ rowsAffected: 2,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, test := range tests {
|
|
|
+ test := test
|
|
|
+ fns := []func(args ...interface{}) (sql.Result, error){
|
|
|
+ func(args ...interface{}) (sql.Result, error) {
|
|
|
+ return exec(&mockedSessionConn{
|
|
|
+ lastInsertId: test.lastInsertId,
|
|
|
+ rowsAffected: test.rowsAffected,
|
|
|
+ err: test.err,
|
|
|
+ delay: test.delay,
|
|
|
+ }, "select user from users where id=?", args...)
|
|
|
+ },
|
|
|
+ func(args ...interface{}) (sql.Result, error) {
|
|
|
+ return execStmt(&mockedStmtConn{
|
|
|
+ lastInsertId: test.lastInsertId,
|
|
|
+ rowsAffected: test.rowsAffected,
|
|
|
+ err: test.err,
|
|
|
+ delay: test.delay,
|
|
|
+ }, args...)
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for i, fn := range fns {
|
|
|
+ i := i
|
|
|
+ fn := fn
|
|
|
+ t.Run(test.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ res, err := fn(test.args...)
|
|
|
+ if i == 0 && test.formatError {
|
|
|
+ assert.NotNil(t, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if !test.formatError && test.hasError {
|
|
|
+ assert.NotNil(t, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ assert.Nil(t, err)
|
|
|
+ lastInsertId, err := res.LastInsertId()
|
|
|
+ assert.Nil(t, err)
|
|
|
+ assert.Equal(t, test.lastInsertId, lastInsertId)
|
|
|
+ rowsAffected, err := res.RowsAffected()
|
|
|
+ assert.Nil(t, err)
|
|
|
+ assert.Equal(t, test.rowsAffected, rowsAffected)
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestStmt_query(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ args []interface{}
|
|
|
+ delay bool
|
|
|
+ formatError bool
|
|
|
+ hasError bool
|
|
|
+ err error
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "normal",
|
|
|
+ args: []interface{}{1},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "wrong format",
|
|
|
+ args: []interface{}{1, 2},
|
|
|
+ formatError: true,
|
|
|
+ hasError: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "query error",
|
|
|
+ args: []interface{}{1},
|
|
|
+ hasError: true,
|
|
|
+ err: errors.New("exec"),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "slowcall",
|
|
|
+ args: []interface{}{1},
|
|
|
+ delay: true,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, test := range tests {
|
|
|
+ test := test
|
|
|
+ fns := []func(args ...interface{}) error{
|
|
|
+ func(args ...interface{}) error {
|
|
|
+ return query(&mockedSessionConn{
|
|
|
+ err: test.err,
|
|
|
+ delay: test.delay,
|
|
|
+ }, func(rows *sql.Rows) error {
|
|
|
+ return nil
|
|
|
+ }, "select user from users where id=?", args...)
|
|
|
+ },
|
|
|
+ func(args ...interface{}) error {
|
|
|
+ return queryStmt(&mockedStmtConn{
|
|
|
+ err: test.err,
|
|
|
+ delay: test.delay,
|
|
|
+ }, func(rows *sql.Rows) error {
|
|
|
+ return nil
|
|
|
+ }, args...)
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for i, fn := range fns {
|
|
|
+ i := i
|
|
|
+ fn := fn
|
|
|
+ t.Run(test.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ err := fn(test.args...)
|
|
|
+ if i == 0 && test.formatError {
|
|
|
+ assert.NotNil(t, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if !test.formatError && test.hasError {
|
|
|
+ assert.NotNil(t, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ assert.Equal(t, errMockedPlaceholder, err)
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type mockedSessionConn struct {
|
|
|
+ lastInsertId int64
|
|
|
+ rowsAffected int64
|
|
|
+ err error
|
|
|
+ delay bool
|
|
|
+}
|
|
|
+
|
|
|
+func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
|
+ if m.delay {
|
|
|
+ time.Sleep(slowThreshold + time.Millisecond)
|
|
|
+ }
|
|
|
+ return mockedResult{
|
|
|
+ lastInsertId: m.lastInsertId,
|
|
|
+ rowsAffected: m.rowsAffected,
|
|
|
+ }, m.err
|
|
|
+}
|
|
|
+
|
|
|
+func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
|
+ if m.delay {
|
|
|
+ time.Sleep(slowThreshold + time.Millisecond)
|
|
|
+ }
|
|
|
+
|
|
|
+ err := errMockedPlaceholder
|
|
|
+ if m.err != nil {
|
|
|
+ err = m.err
|
|
|
+ }
|
|
|
+ return new(sql.Rows), err
|
|
|
+}
|
|
|
+
|
|
|
+type mockedStmtConn struct {
|
|
|
+ lastInsertId int64
|
|
|
+ rowsAffected int64
|
|
|
+ err error
|
|
|
+ delay bool
|
|
|
+}
|
|
|
+
|
|
|
+func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
|
|
|
+ if m.delay {
|
|
|
+ time.Sleep(slowThreshold + time.Millisecond)
|
|
|
+ }
|
|
|
+ return mockedResult{
|
|
|
+ lastInsertId: m.lastInsertId,
|
|
|
+ rowsAffected: m.rowsAffected,
|
|
|
+ }, m.err
|
|
|
+}
|
|
|
+
|
|
|
+func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
|
|
|
+ if m.delay {
|
|
|
+ time.Sleep(slowThreshold + time.Millisecond)
|
|
|
+ }
|
|
|
+
|
|
|
+ err := errMockedPlaceholder
|
|
|
+ if m.err != nil {
|
|
|
+ err = m.err
|
|
|
+ }
|
|
|
+ return new(sql.Rows), err
|
|
|
+}
|
|
|
+
|
|
|
+type mockedResult struct {
|
|
|
+ lastInsertId int64
|
|
|
+ rowsAffected int64
|
|
|
+}
|
|
|
+
|
|
|
+func (m mockedResult) LastInsertId() (int64, error) {
|
|
|
+ return m.lastInsertId, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (m mockedResult) RowsAffected() (int64, error) {
|
|
|
+ return m.rowsAffected, nil
|
|
|
+}
|