123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- package sqlx
- import (
- "context"
- "database/sql"
- "errors"
- "testing"
- "time"
- "github.com/stretchr/testify/assert"
- )
// errMockedPlaceholder is the error the mocked QueryContext implementations
// hand back when the test case configures no error of its own. They always
// pair it with a zero-value *sql.Rows, presumably so callers bail out before
// touching those rows.
var (
	errMockedPlaceholder = errors.New("placeholder")
)
- func TestStmt_exec(t *testing.T) {
- tests := []struct {
- name string
- query string
- args []any
- delay bool
- hasError bool
- err error
- lastInsertId int64
- rowsAffected int64
- }{
- {
- name: "normal",
- query: "select user from users where id=?",
- args: []any{1},
- lastInsertId: 1,
- rowsAffected: 2,
- },
- {
- name: "exec error",
- query: "select user from users where id=?",
- args: []any{1},
- hasError: true,
- err: errors.New("exec"),
- },
- {
- name: "exec more args error",
- query: "select user from users where id=? and name=?",
- args: []any{1},
- hasError: true,
- err: errors.New("exec"),
- },
- {
- name: "slowcall",
- query: "select user from users where id=?",
- args: []any{1},
- delay: true,
- lastInsertId: 1,
- rowsAffected: 2,
- },
- }
- for _, test := range tests {
- test := test
- fns := []func(args ...any) (sql.Result, error){
- func(args ...any) (sql.Result, error) {
- return exec(context.Background(), &mockedSessionConn{
- lastInsertId: test.lastInsertId,
- rowsAffected: test.rowsAffected,
- err: test.err,
- delay: test.delay,
- }, test.query, args...)
- },
- func(args ...any) (sql.Result, error) {
- return execStmt(context.Background(), &mockedStmtConn{
- lastInsertId: test.lastInsertId,
- rowsAffected: test.rowsAffected,
- err: test.err,
- delay: test.delay,
- }, test.query, args...)
- },
- }
- for _, fn := range fns {
- fn := fn
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
- res, err := fn(test.args...)
- if test.hasError {
- assert.NotNil(t, err)
- return
- }
- assert.Nil(t, err)
- lastInsertId, err := res.LastInsertId()
- assert.Nil(t, err)
- assert.Equal(t, test.lastInsertId, lastInsertId)
- rowsAffected, err := res.RowsAffected()
- assert.Nil(t, err)
- assert.Equal(t, test.rowsAffected, rowsAffected)
- })
- }
- }
- }
- func TestStmt_query(t *testing.T) {
- tests := []struct {
- name string
- query string
- args []any
- delay bool
- hasError bool
- err error
- }{
- {
- name: "normal",
- query: "select user from users where id=?",
- args: []any{1},
- },
- {
- name: "query error",
- query: "select user from users where id=?",
- args: []any{1},
- hasError: true,
- err: errors.New("exec"),
- },
- {
- name: "query more args error",
- query: "select user from users where id=? and name=?",
- args: []any{1},
- hasError: true,
- err: errors.New("exec"),
- },
- {
- name: "slowcall",
- query: "select user from users where id=?",
- args: []any{1},
- delay: true,
- },
- }
- for _, test := range tests {
- test := test
- fns := []func(args ...any) error{
- func(args ...any) error {
- return query(context.Background(), &mockedSessionConn{
- err: test.err,
- delay: test.delay,
- }, func(rows *sql.Rows) error {
- return nil
- }, test.query, args...)
- },
- func(args ...any) error {
- return queryStmt(context.Background(), &mockedStmtConn{
- err: test.err,
- delay: test.delay,
- }, func(rows *sql.Rows) error {
- return nil
- }, test.query, args...)
- },
- }
- for _, fn := range fns {
- fn := fn
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
- err := fn(test.args...)
- if test.hasError {
- assert.NotNil(t, err)
- return
- }
- assert.NotNil(t, err)
- })
- }
- }
- }
- func TestSetSlowThreshold(t *testing.T) {
- assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
- SetSlowThreshold(time.Second)
- assert.Equal(t, time.Second, slowThreshold.Load())
- }
- func TestDisableLog(t *testing.T) {
- assert.True(t, logSql.True())
- assert.True(t, logSlowSql.True())
- defer func() {
- logSql.Set(true)
- logSlowSql.Set(true)
- }()
- DisableLog()
- assert.False(t, logSql.True())
- assert.False(t, logSlowSql.True())
- }
- func TestDisableStmtLog(t *testing.T) {
- assert.True(t, logSql.True())
- assert.True(t, logSlowSql.True())
- defer func() {
- logSql.Set(true)
- logSlowSql.Set(true)
- }()
- DisableStmtLog()
- assert.False(t, logSql.True())
- assert.True(t, logSlowSql.True())
- }
- func TestNilGuard(t *testing.T) {
- assert.True(t, logSql.True())
- assert.True(t, logSlowSql.True())
- defer func() {
- logSql.Set(true)
- logSlowSql.Set(true)
- }()
- DisableLog()
- guard := newGuard("any")
- assert.Nil(t, guard.start("foo", "bar"))
- guard.finish(context.Background(), nil)
- assert.Equal(t, nilGuard{}, guard)
- }
- type mockedSessionConn struct {
- lastInsertId int64
- rowsAffected int64
- err error
- delay bool
- }
- func (m *mockedSessionConn) Exec(query string, args ...any) (sql.Result, error) {
- return m.ExecContext(context.Background(), query, args...)
- }
- func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
- if m.delay {
- time.Sleep(defaultSlowThreshold + time.Millisecond)
- }
- return mockedResult{
- lastInsertId: m.lastInsertId,
- rowsAffected: m.rowsAffected,
- }, m.err
- }
- func (m *mockedSessionConn) Query(query string, args ...any) (*sql.Rows, error) {
- return m.QueryContext(context.Background(), query, args...)
- }
- func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
- if m.delay {
- time.Sleep(defaultSlowThreshold + time.Millisecond)
- }
- err := errMockedPlaceholder
- if m.err != nil {
- err = m.err
- }
- return new(sql.Rows), err
- }
- type mockedStmtConn struct {
- lastInsertId int64
- rowsAffected int64
- err error
- delay bool
- }
- func (m *mockedStmtConn) Exec(args ...any) (sql.Result, error) {
- return m.ExecContext(context.Background(), args...)
- }
- func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...any) (sql.Result, error) {
- if m.delay {
- time.Sleep(defaultSlowThreshold + time.Millisecond)
- }
- return mockedResult{
- lastInsertId: m.lastInsertId,
- rowsAffected: m.rowsAffected,
- }, m.err
- }
- func (m *mockedStmtConn) Query(args ...any) (*sql.Rows, error) {
- return m.QueryContext(context.Background(), args...)
- }
- func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...any) (*sql.Rows, error) {
- if m.delay {
- time.Sleep(defaultSlowThreshold + time.Millisecond)
- }
- err := errMockedPlaceholder
- if m.err != nil {
- err = m.err
- }
- return new(sql.Rows), err
- }
// mockedResult is a canned sql.Result whose accessors report the stored
// counters and never fail.
type mockedResult struct {
	lastInsertId, rowsAffected int64
}

// LastInsertId reports the stored insert id with a nil error.
func (r mockedResult) LastInsertId() (int64, error) {
	id := r.lastInsertId
	return id, nil
}

// RowsAffected reports the stored affected-row count with a nil error.
func (r mockedResult) RowsAffected() (int64, error) {
	n := r.rowsAffected
	return n, nil
}
|