浏览代码

feat: support ctx in sqlx/sqlc, listed in ROADMAP (#1535)

* feat: support ctx in sqlx/sqlc

* chore: update roadmap

* fix: context.Canceled should be acceptable

* use %w to wrap errors

* chore: remove unused vars
Kevin Wan 3 年之前
父节点
当前提交
607bae27fa

+ 110 - 27
core/stores/sqlc/cachedsql.go

@@ -1,6 +1,7 @@
 package sqlc
 package sqlc
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"time"
 	"time"
 
 
@@ -18,19 +19,27 @@ var (
 	ErrNotFound = sqlx.ErrNotFound
 	ErrNotFound = sqlx.ErrNotFound
 
 
 	// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
 	// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
-	exclusiveCalls = syncx.NewSingleFlight()
-	stats          = cache.NewStat("sqlc")
+	singleFlights = syncx.NewSingleFlight()
+	stats         = cache.NewStat("sqlc")
 )
 )
 
 
 type (
 type (
 	// ExecFn defines the sql exec method.
 	// ExecFn defines the sql exec method.
 	ExecFn func(conn sqlx.SqlConn) (sql.Result, error)
 	ExecFn func(conn sqlx.SqlConn) (sql.Result, error)
+	// ExecCtxFn defines the sql exec method.
+	ExecCtxFn func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error)
 	// IndexQueryFn defines the query method that based on unique indexes.
 	// IndexQueryFn defines the query method that based on unique indexes.
 	IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error)
 	IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error)
+	// IndexQueryCtxFn defines the query method that based on unique indexes.
+	IndexQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error)
 	// PrimaryQueryFn defines the query method that based on primary keys.
 	// PrimaryQueryFn defines the query method that based on primary keys.
 	PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error
 	PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error
+	// PrimaryQueryCtxFn defines the query method that based on primary keys.
+	PrimaryQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v, primary interface{}) error
 	// QueryFn defines the query method.
 	// QueryFn defines the query method.
 	QueryFn func(conn sqlx.SqlConn, v interface{}) error
 	QueryFn func(conn sqlx.SqlConn, v interface{}) error
+	// QueryCtxFn defines the query method.
+	QueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) error
 
 
 	// A CachedConn is a DB connection with cache capability.
 	// A CachedConn is a DB connection with cache capability.
 	CachedConn struct {
 	CachedConn struct {
@@ -41,7 +50,7 @@ type (
 
 
 // NewConn returns a CachedConn with a redis cluster cache.
 // NewConn returns a CachedConn with a redis cluster cache.
 func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn {
 func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn {
-	cc := cache.New(c, exclusiveCalls, stats, sql.ErrNoRows, opts...)
+	cc := cache.New(c, singleFlights, stats, sql.ErrNoRows, opts...)
 	return NewConnWithCache(db, cc)
 	return NewConnWithCache(db, cc)
 }
 }
 
 
@@ -55,28 +64,46 @@ func NewConnWithCache(db sqlx.SqlConn, c cache.Cache) CachedConn {
 
 
 // NewNodeConn returns a CachedConn with a redis node cache.
 // NewNodeConn returns a CachedConn with a redis node cache.
 func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn {
 func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn {
-	c := cache.NewNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...)
+	c := cache.NewNode(rds, singleFlights, stats, sql.ErrNoRows, opts...)
 	return NewConnWithCache(db, c)
 	return NewConnWithCache(db, c)
 }
 }
 
 
 // DelCache deletes cache with keys.
 // DelCache deletes cache with keys.
 func (cc CachedConn) DelCache(keys ...string) error {
 func (cc CachedConn) DelCache(keys ...string) error {
-	return cc.cache.Del(keys...)
+	return cc.DelCacheCtx(context.Background(), keys...)
+}
+
+// DelCacheCtx deletes cache with keys.
+func (cc CachedConn) DelCacheCtx(ctx context.Context, keys ...string) error {
+	return cc.cache.DelCtx(ctx, keys...)
 }
 }
 
 
 // GetCache unmarshals cache with given key into v.
 // GetCache unmarshals cache with given key into v.
 func (cc CachedConn) GetCache(key string, v interface{}) error {
 func (cc CachedConn) GetCache(key string, v interface{}) error {
-	return cc.cache.Get(key, v)
+	return cc.GetCacheCtx(context.Background(), key, v)
+}
+
+// GetCacheCtx unmarshals cache with given key into v.
+func (cc CachedConn) GetCacheCtx(ctx context.Context, key string, v interface{}) error {
+	return cc.cache.GetCtx(ctx, key, v)
 }
 }
 
 
 // Exec runs given exec on given keys, and returns execution result.
 // Exec runs given exec on given keys, and returns execution result.
 func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
 func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
-	res, err := exec(cc.db)
+	execCtx := func(_ context.Context, conn sqlx.SqlConn) (sql.Result, error) {
+		return exec(conn)
+	}
+	return cc.ExecCtx(context.Background(), execCtx, keys...)
+}
+
+// ExecCtx runs given exec on given keys, and returns execution result.
+func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string) (sql.Result, error) {
+	res, err := exec(ctx, cc.db)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	if err := cc.DelCache(keys...); err != nil {
+	if err := cc.DelCacheCtx(ctx, keys...); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
@@ -85,31 +112,61 @@ func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
 
 
 // ExecNoCache runs exec with given sql statement, without affecting cache.
 // ExecNoCache runs exec with given sql statement, without affecting cache.
 func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) {
 func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) {
-	return cc.db.Exec(q, args...)
+	return cc.ExecNoCacheCtx(context.Background(), q, args...)
+}
+
+// ExecNoCacheCtx runs exec with given sql statement, without affecting cache.
+func (cc CachedConn) ExecNoCacheCtx(ctx context.Context, q string, args ...interface{}) (
+	sql.Result, error) {
+	return cc.db.ExecCtx(ctx, q, args...)
 }
 }
 
 
 // QueryRow unmarshals into v with given key and query func.
 // QueryRow unmarshals into v with given key and query func.
 func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error {
 func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error {
-	return cc.cache.Take(v, key, func(v interface{}) error {
-		return query(cc.db, v)
+	queryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) error {
+		return query(conn, v)
+	}
+	return cc.QueryRowCtx(context.Background(), v, key, queryCtx)
+}
+
+// QueryRowCtx unmarshals into v with given key and query func.
+func (cc CachedConn) QueryRowCtx(ctx context.Context, v interface{}, key string, query QueryCtxFn) error {
+	return cc.cache.TakeCtx(ctx, v, key, func(v interface{}) error {
+		return query(ctx, cc.db, v)
 	})
 	})
 }
 }
 
 
 // QueryRowIndex unmarshals into v with given key.
 // QueryRowIndex unmarshals into v with given key.
 func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string,
 func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string,
 	indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error {
 	indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error {
+	indexQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error) {
+		return indexQuery(conn, v)
+	}
+	primaryQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v, primary interface{}) error {
+		return primaryQuery(conn, v, primary)
+	}
+
+	return cc.QueryRowIndexCtx(context.Background(), v, key, keyer, indexQueryCtx, primaryQueryCtx)
+}
+
+// QueryRowIndexCtx unmarshals into v with given key.
+func (cc CachedConn) QueryRowIndexCtx(ctx context.Context, v interface{}, key string,
+	keyer func(primary interface{}) string, indexQuery IndexQueryCtxFn,
+	primaryQuery PrimaryQueryCtxFn) error {
 	var primaryKey interface{}
 	var primaryKey interface{}
 	var found bool
 	var found bool
 
 
-	if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) {
-		primaryKey, err = indexQuery(cc.db, v)
-		if err != nil {
-			return
-		}
-
-		found = true
-		return cc.cache.SetWithExpire(keyer(primaryKey), v, expire+cacheSafeGapBetweenIndexAndPrimary)
-	}); err != nil {
+	if err := cc.cache.TakeWithExpireCtx(ctx, &primaryKey, key,
+		func(val interface{}, expire time.Duration) (err error) {
+			primaryKey, err = indexQuery(ctx, cc.db, v)
+			if err != nil {
+				return
+			}
+
+			found = true
+			return cc.cache.SetWithExpireCtx(ctx, keyer(primaryKey), v,
+				expire+cacheSafeGapBetweenIndexAndPrimary)
+		}); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -117,28 +174,54 @@ func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary
 		return nil
 		return nil
 	}
 	}
 
 
-	return cc.cache.Take(v, keyer(primaryKey), func(v interface{}) error {
-		return primaryQuery(cc.db, v, primaryKey)
+	return cc.cache.TakeCtx(ctx, v, keyer(primaryKey), func(v interface{}) error {
+		return primaryQuery(ctx, cc.db, v, primaryKey)
 	})
 	})
 }
 }
 
 
 // QueryRowNoCache unmarshals into v with given statement.
 // QueryRowNoCache unmarshals into v with given statement.
 func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error {
 func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error {
-	return cc.db.QueryRow(v, q, args...)
+	return cc.QueryRowNoCacheCtx(context.Background(), v, q, args...)
+}
+
+// QueryRowNoCacheCtx unmarshals into v with given statement.
+func (cc CachedConn) QueryRowNoCacheCtx(ctx context.Context, v interface{}, q string,
+	args ...interface{}) error {
+	return cc.db.QueryRowCtx(ctx, v, q, args...)
 }
 }
 
 
 // QueryRowsNoCache unmarshals into v with given statement.
 // QueryRowsNoCache unmarshals into v with given statement.
 // It doesn't use cache, because it might cause consistency problem.
 // It doesn't use cache, because it might cause consistency problem.
 func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error {
 func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error {
-	return cc.db.QueryRows(v, q, args...)
+	return cc.QueryRowsNoCacheCtx(context.Background(), v, q, args...)
+}
+
+// QueryRowsNoCacheCtx unmarshals into v with given statement.
+// It doesn't use cache, because it might cause consistency problem.
+func (cc CachedConn) QueryRowsNoCacheCtx(ctx context.Context, v interface{}, q string,
+	args ...interface{}) error {
+	return cc.db.QueryRowsCtx(ctx, v, q, args...)
 }
 }
 
 
 // SetCache sets v into cache with given key.
 // SetCache sets v into cache with given key.
-func (cc CachedConn) SetCache(key string, v interface{}) error {
-	return cc.cache.Set(key, v)
+func (cc CachedConn) SetCache(key string, val interface{}) error {
+	return cc.SetCacheCtx(context.Background(), key, val)
+}
+
+// SetCacheCtx sets v into cache with given key.
+func (cc CachedConn) SetCacheCtx(ctx context.Context, key string, val interface{}) error {
+	return cc.cache.SetCtx(ctx, key, val)
 }
 }
 
 
 // Transact runs given fn in transaction mode.
 // Transact runs given fn in transaction mode.
 func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
 func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
-	return cc.db.Transact(fn)
+	fnCtx := func(_ context.Context, session sqlx.Session) error {
+		return fn(session)
+	}
+	return cc.TransactCtx(context.Background(), fnCtx)
+}
+
+// TransactCtx runs given fn in transaction mode.
+func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
+	return cc.db.TransactCtx(ctx, fn)
 }
 }

+ 47 - 4
core/stores/sqlc/cachedsql_test.go

@@ -1,6 +1,7 @@
 package sqlc
 package sqlc
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
@@ -568,7 +569,7 @@ func TestNewConnWithCache(t *testing.T) {
 	defer clean()
 	defer clean()
 
 
 	var conn trackedConn
 	var conn trackedConn
-	c := NewConnWithCache(&conn, cache.NewNode(r, exclusiveCalls, stats, sql.ErrNoRows))
+	c := NewConnWithCache(&conn, cache.NewNode(r, singleFlights, stats, sql.ErrNoRows))
 	_, err = c.ExecNoCache("delete from user_table where id='kevin'")
 	_, err = c.ExecNoCache("delete from user_table where id='kevin'")
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.True(t, conn.execValue)
 	assert.True(t, conn.execValue)
@@ -585,6 +586,30 @@ type dummySqlConn struct {
 	queryRow func(interface{}, string, ...interface{}) error
 	queryRow func(interface{}, string, ...interface{}) error
 }
 }
 
 
+func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+	return nil, nil
+}
+
+func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
+	return nil, nil
+}
+
+func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	return nil
+}
+
+func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	return nil
+}
+
+func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	return nil
+}
+
+func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
+	return nil
+}
+
 func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) {
 func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) {
 	return nil, nil
 	return nil, nil
 }
 }
@@ -594,6 +619,10 @@ func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
 }
 }
 
 
 func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error {
 func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error {
+	return d.QueryRowCtx(context.Background(), v, query, args...)
+}
+
+func (d dummySqlConn) QueryRowCtx(_ context.Context, v interface{}, query string, args ...interface{}) error {
 	if d.queryRow != nil {
 	if d.queryRow != nil {
 		return d.queryRow(v, query, args...)
 		return d.queryRow(v, query, args...)
 	}
 	}
@@ -628,13 +657,21 @@ type trackedConn struct {
 }
 }
 
 
 func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
 func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
+	return c.ExecCtx(context.Background(), query, args...)
+}
+
+func (c *trackedConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 	c.execValue = true
 	c.execValue = true
-	return c.dummySqlConn.Exec(query, args...)
+	return c.dummySqlConn.ExecCtx(ctx, query, args...)
 }
 }
 
 
 func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
 func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
+	return c.QueryRowsCtx(context.Background(), v, query, args...)
+}
+
+func (c *trackedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
 	c.queryRowsValue = true
 	c.queryRowsValue = true
-	return c.dummySqlConn.QueryRows(v, query, args...)
+	return c.dummySqlConn.QueryRowsCtx(ctx, v, query, args...)
 }
 }
 
 
 func (c *trackedConn) RawDB() (*sql.DB, error) {
 func (c *trackedConn) RawDB() (*sql.DB, error) {
@@ -642,6 +679,12 @@ func (c *trackedConn) RawDB() (*sql.DB, error) {
 }
 }
 
 
 func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
 func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
+	return c.TransactCtx(context.Background(), func(_ context.Context, session sqlx.Session) error {
+		return fn(session)
+	})
+}
+
+func (c *trackedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
 	c.transactValue = true
 	c.transactValue = true
-	return c.dummySqlConn.Transact(fn)
+	return c.dummySqlConn.TransactCtx(ctx, fn)
 }
 }

+ 30 - 1
core/stores/sqlx/bulkinserter_test.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"errors"
 	"errors"
 	"strconv"
 	"strconv"
@@ -17,12 +18,40 @@ type mockedConn struct {
 	execErr error
 	execErr error
 }
 }
 
 
-func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
+func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...interface{}) (sql.Result, error) {
 	c.query = query
 	c.query = query
 	c.args = args
 	c.args = args
 	return nil, c.execErr
 	return nil, c.execErr
 }
 }
 
 
+func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
+	panic("implement me")
+}
+
+func (c *mockedConn) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	panic("implement me")
+}
+
+func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	panic("implement me")
+}
+
+func (c *mockedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	panic("implement me")
+}
+
+func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	panic("implement me")
+}
+
+func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
+	panic("should not called")
+}
+
+func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
+	return c.ExecCtx(context.Background(), query, args...)
+}
+
 func (c *mockedConn) Prepare(query string) (StmtSession, error) {
 func (c *mockedConn) Prepare(query string) (StmtSession, error) {
 	panic("should not called")
 	panic("should not called")
 }
 }

+ 57 - 56
core/stores/sqlx/orm_test.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"errors"
 	"errors"
 	"testing"
 	"testing"
@@ -16,7 +17,7 @@ func TestUnmarshalRowBool(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value bool
 		var value bool
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.True(t, value)
 		assert.True(t, value)
@@ -29,7 +30,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value bool
 		var value bool
-		assert.NotNil(t, query(db, func(rows *sql.Rows) error {
+		assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(value, rows, true)
 			return unmarshalRow(value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 	})
 	})
@@ -41,7 +42,7 @@ func TestUnmarshalRowInt(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value int
 		var value int
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, 2, value)
 		assert.EqualValues(t, 2, value)
@@ -54,7 +55,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value int8
 		var value int8
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, int8(3), value)
 		assert.EqualValues(t, int8(3), value)
@@ -67,7 +68,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value int16
 		var value int16
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.Equal(t, int16(4), value)
 		assert.Equal(t, int16(4), value)
@@ -80,7 +81,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value int32
 		var value int32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.Equal(t, int32(5), value)
 		assert.Equal(t, int32(5), value)
@@ -93,7 +94,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value int64
 		var value int64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, int64(6), value)
 		assert.EqualValues(t, int64(6), value)
@@ -106,7 +107,7 @@ func TestUnmarshalRowUint(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value uint
 		var value uint
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, uint(2), value)
 		assert.EqualValues(t, uint(2), value)
@@ -119,7 +120,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value uint8
 		var value uint8
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, uint8(3), value)
 		assert.EqualValues(t, uint8(3), value)
@@ -132,7 +133,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value uint16
 		var value uint16
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, uint16(4), value)
 		assert.EqualValues(t, uint16(4), value)
@@ -145,7 +146,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value uint32
 		var value uint32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, uint32(5), value)
 		assert.EqualValues(t, uint32(5), value)
@@ -158,7 +159,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value uint64
 		var value uint64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, uint16(6), value)
 		assert.EqualValues(t, uint16(6), value)
@@ -171,7 +172,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value float32
 		var value float32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, float32(7), value)
 		assert.EqualValues(t, float32(7), value)
@@ -184,7 +185,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value float64
 		var value float64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, float64(8), value)
 		assert.EqualValues(t, float64(8), value)
@@ -198,7 +199,7 @@ func TestUnmarshalRowString(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value string
 		var value string
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&value, rows, true)
 			return unmarshalRow(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -215,7 +216,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(value, rows, true)
 			return unmarshalRow(value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 		assert.Equal(t, "liao", value.Name)
 		assert.Equal(t, "liao", value.Name)
@@ -233,7 +234,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(value, rows, true)
 			return unmarshalRow(value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 		assert.Equal(t, "liao", value.Name)
 		assert.Equal(t, "liao", value.Name)
@@ -251,7 +252,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
 		rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
 		rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
-		assert.NotNil(t, query(db, func(rows *sql.Rows) error {
+		assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(value, rows, true)
 			return unmarshalRow(value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 	})
 	})
@@ -264,7 +265,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []bool
 		var value []bool
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -278,7 +279,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []int
 		var value []int
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -292,7 +293,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []int8
 		var value []int8
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -306,7 +307,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []int16
 		var value []int16
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -320,7 +321,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []int32
 		var value []int32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -334,7 +335,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []int64
 		var value []int64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -348,7 +349,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []uint
 		var value []uint
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -362,7 +363,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []uint8
 		var value []uint8
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -376,7 +377,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []uint16
 		var value []uint16
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -390,7 +391,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []uint32
 		var value []uint32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -404,7 +405,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []uint64
 		var value []uint64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -418,7 +419,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []float32
 		var value []float32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -432,7 +433,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []float64
 		var value []float64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -446,7 +447,7 @@ func TestUnmarshalRowsString(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []string
 		var value []string
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -462,7 +463,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*bool
 		var value []*bool
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -478,7 +479,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*int
 		var value []*int
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -494,7 +495,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*int8
 		var value []*int8
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -510,7 +511,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*int16
 		var value []*int16
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -526,7 +527,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*int32
 		var value []*int32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -542,7 +543,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*int64
 		var value []*int64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -558,7 +559,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*uint
 		var value []*uint
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -574,7 +575,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*uint8
 		var value []*uint8
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -590,7 +591,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*uint16
 		var value []*uint16
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -606,7 +607,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*uint32
 		var value []*uint32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -622,7 +623,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*uint64
 		var value []*uint64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -638,7 +639,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*float32
 		var value []*float32
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -654,7 +655,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*float64
 		var value []*float64
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -670,7 +671,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 
 		var value []*string
 		var value []*string
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select value from users where user=?", "anyone"))
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 		assert.EqualValues(t, expect, value)
@@ -699,7 +700,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 
 
@@ -739,7 +740,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
 		rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
 		rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
 			"first", "firstnullstring").AddRow("second", nil)
 			"first", "firstnullstring").AddRow("second", nil)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 
 
@@ -773,7 +774,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 
 
@@ -814,7 +815,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
 		rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age, value from users where user=?", "anyone"))
 		}, "select name, age, value from users where user=?", "anyone"))
 
 
@@ -856,7 +857,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
 		rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age, value from users where user=?", "anyone"))
 		}, "select name, age, value from users where user=?", "anyone"))
 
 
@@ -890,7 +891,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 
 
@@ -923,7 +924,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 
 
@@ -956,7 +957,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
 			return unmarshalRows(&value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
 		}, "select name, age from users where user=?", "anyone"))
 
 
@@ -976,7 +977,7 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
 			User string `db:"user"`
 			User string `db:"user"`
 			Age  int    `db:"age"`
 			Age  int    `db:"age"`
 		}
 		}
-		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(&r, rows, false)
 			return unmarshalRow(&r, rows, false)
 		}, "select age from users where user=?", "anyone"))
 		}, "select age from users where user=?", "anyone"))
 		assert.Empty(t, r.User)
 		assert.Empty(t, r.User)
@@ -1027,7 +1028,7 @@ func TestUnmarshalRowError(t *testing.T) {
 					User string `db:"user"`
 					User string `db:"user"`
 					Age  int    `db:"age"`
 					Age  int    `db:"age"`
 				}
 				}
-				test.validate(query(db, func(rows *sql.Rows) error {
+				test.validate(query(context.Background(), db, func(rows *sql.Rows) error {
 					scanner := mockedScanner{
 					scanner := mockedScanner{
 						colErr:  test.colErr,
 						colErr:  test.colErr,
 						scanErr: test.scanErr,
 						scanErr: test.scanErr,

+ 89 - 16
core/stores/sqlx/sqlconn.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 
 
 	"github.com/zeromicro/go-zero/core/breaker"
 	"github.com/zeromicro/go-zero/core/breaker"
@@ -14,11 +15,17 @@ type (
 	// Session stands for raw connections or transaction sessions
 	// Session stands for raw connections or transaction sessions
 	Session interface {
 	Session interface {
 		Exec(query string, args ...interface{}) (sql.Result, error)
 		Exec(query string, args ...interface{}) (sql.Result, error)
+		ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
 		Prepare(query string) (StmtSession, error)
 		Prepare(query string) (StmtSession, error)
+		PrepareCtx(ctx context.Context, query string) (StmtSession, error)
 		QueryRow(v interface{}, query string, args ...interface{}) error
 		QueryRow(v interface{}, query string, args ...interface{}) error
+		QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
 		QueryRowPartial(v interface{}, query string, args ...interface{}) error
 		QueryRowPartial(v interface{}, query string, args ...interface{}) error
+		QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
 		QueryRows(v interface{}, query string, args ...interface{}) error
 		QueryRows(v interface{}, query string, args ...interface{}) error
+		QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
 		QueryRowsPartial(v interface{}, query string, args ...interface{}) error
 		QueryRowsPartial(v interface{}, query string, args ...interface{}) error
+		QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
 	}
 	}
 
 
 	// SqlConn only stands for raw connections, so Transact method can be called.
 	// SqlConn only stands for raw connections, so Transact method can be called.
@@ -27,7 +34,8 @@ type (
 		// RawDB is for other ORM to operate with, use it with caution.
 		// RawDB is for other ORM to operate with, use it with caution.
 		// Notice: don't close it.
 		// Notice: don't close it.
 		RawDB() (*sql.DB, error)
 		RawDB() (*sql.DB, error)
-		Transact(func(session Session) error) error
+		Transact(fn func(Session) error) error
+		TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
 	}
 	}
 
 
 	// SqlOption defines the method to customize a sql connection.
 	// SqlOption defines the method to customize a sql connection.
@@ -37,10 +45,15 @@ type (
 	StmtSession interface {
 	StmtSession interface {
 		Close() error
 		Close() error
 		Exec(args ...interface{}) (sql.Result, error)
 		Exec(args ...interface{}) (sql.Result, error)
+		ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
 		QueryRow(v interface{}, args ...interface{}) error
 		QueryRow(v interface{}, args ...interface{}) error
+		QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
 		QueryRowPartial(v interface{}, args ...interface{}) error
 		QueryRowPartial(v interface{}, args ...interface{}) error
+		QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
 		QueryRows(v interface{}, args ...interface{}) error
 		QueryRows(v interface{}, args ...interface{}) error
+		QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
 		QueryRowsPartial(v interface{}, args ...interface{}) error
 		QueryRowsPartial(v interface{}, args ...interface{}) error
+		QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
 	}
 	}
 
 
 	// thread-safe
 	// thread-safe
@@ -58,7 +71,9 @@ type (
 
 
 	sessionConn interface {
 	sessionConn interface {
 		Exec(query string, args ...interface{}) (sql.Result, error)
 		Exec(query string, args ...interface{}) (sql.Result, error)
+		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
 		Query(query string, args ...interface{}) (*sql.Rows, error)
 		Query(query string, args ...interface{}) (*sql.Rows, error)
+		QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
 	}
 	}
 
 
 	statement struct {
 	statement struct {
@@ -68,7 +83,9 @@ type (
 
 
 	stmtConn interface {
 	stmtConn interface {
 		Exec(args ...interface{}) (sql.Result, error)
 		Exec(args ...interface{}) (sql.Result, error)
+		ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
 		Query(args ...interface{}) (*sql.Rows, error)
 		Query(args ...interface{}) (*sql.Rows, error)
+		QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
 	}
 	}
 )
 )
 
 
@@ -112,6 +129,11 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
 }
 }
 
 
 func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
 func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
+	return db.ExecCtx(context.Background(), q, args...)
+}
+
+func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
+	result sql.Result, err error) {
 	err = db.brk.DoWithAcceptable(func() error {
 	err = db.brk.DoWithAcceptable(func() error {
 		var conn *sql.DB
 		var conn *sql.DB
 		conn, err = db.connProv()
 		conn, err = db.connProv()
@@ -120,7 +142,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
 			return err
 			return err
 		}
 		}
 
 
-		result, err = exec(conn, q, args...)
+		result, err = exec(ctx, conn, q, args...)
 		return err
 		return err
 	}, db.acceptable)
 	}, db.acceptable)
 
 
@@ -128,6 +150,10 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
 }
 }
 
 
 func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
 func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
+	return db.PrepareCtx(context.Background(), query)
+}
+
+func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
 	err = db.brk.DoWithAcceptable(func() error {
 	err = db.brk.DoWithAcceptable(func() error {
 		var conn *sql.DB
 		var conn *sql.DB
 		conn, err = db.connProv()
 		conn, err = db.connProv()
@@ -136,7 +162,7 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
 			return err
 			return err
 		}
 		}
 
 
-		st, err := conn.Prepare(query)
+		st, err := conn.PrepareContext(ctx, query)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -152,25 +178,45 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
 }
 }
 
 
 func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
 func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
-	return db.queryRows(func(rows *sql.Rows) error {
+	return db.QueryRowCtx(context.Background(), v, q, args...)
+}
+
+func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
+	args ...interface{}) error {
+	return db.queryRows(ctx, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, true)
 		return unmarshalRow(v, rows, true)
 	}, q, args...)
 	}, q, args...)
 }
 }
 
 
 func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
 func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
-	return db.queryRows(func(rows *sql.Rows) error {
+	return db.QueryRowPartialCtx(context.Background(), v, q, args...)
+}
+
+func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
+	q string, args ...interface{}) error {
+	return db.queryRows(ctx, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, false)
 		return unmarshalRow(v, rows, false)
 	}, q, args...)
 	}, q, args...)
 }
 }
 
 
 func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
 func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
-	return db.queryRows(func(rows *sql.Rows) error {
+	return db.QueryRowsCtx(context.Background(), v, q, args...)
+}
+
+func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
+	args ...interface{}) error {
+	return db.queryRows(ctx, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, true)
 		return unmarshalRows(v, rows, true)
 	}, q, args...)
 	}, q, args...)
 }
 }
 
 
 func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
 func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
-	return db.queryRows(func(rows *sql.Rows) error {
+	return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
+}
+
+func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
+	q string, args ...interface{}) error {
+	return db.queryRows(ctx, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, false)
 		return unmarshalRows(v, rows, false)
 	}, q, args...)
 	}, q, args...)
 }
 }
@@ -180,13 +226,19 @@ func (db *commonSqlConn) RawDB() (*sql.DB, error) {
 }
 }
 
 
 func (db *commonSqlConn) Transact(fn func(Session) error) error {
 func (db *commonSqlConn) Transact(fn func(Session) error) error {
+	return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
+		return fn(session)
+	})
+}
+
+func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
 	return db.brk.DoWithAcceptable(func() error {
 	return db.brk.DoWithAcceptable(func() error {
-		return transact(db, db.beginTx, fn)
+		return transact(ctx, db, db.beginTx, fn)
 	}, db.acceptable)
 	}, db.acceptable)
 }
 }
 
 
 func (db *commonSqlConn) acceptable(err error) bool {
 func (db *commonSqlConn) acceptable(err error) bool {
-	ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
+	ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
 	if db.accept == nil {
 	if db.accept == nil {
 		return ok
 		return ok
 	}
 	}
@@ -194,7 +246,8 @@ func (db *commonSqlConn) acceptable(err error) bool {
 	return ok || db.accept(err)
 	return ok || db.accept(err)
 }
 }
 
 
-func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
+func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
+	q string, args ...interface{}) error {
 	var qerr error
 	var qerr error
 	return db.brk.DoWithAcceptable(func() error {
 	return db.brk.DoWithAcceptable(func() error {
 		conn, err := db.connProv()
 		conn, err := db.connProv()
@@ -203,7 +256,7 @@ func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args
 			return err
 			return err
 		}
 		}
 
 
-		return query(conn, func(rows *sql.Rows) error {
+		return query(ctx, conn, func(rows *sql.Rows) error {
 			qerr = scanner(rows)
 			qerr = scanner(rows)
 			return qerr
 			return qerr
 		}, q, args...)
 		}, q, args...)
@@ -217,29 +270,49 @@ func (s statement) Close() error {
 }
 }
 
 
 func (s statement) Exec(args ...interface{}) (sql.Result, error) {
 func (s statement) Exec(args ...interface{}) (sql.Result, error) {
-	return execStmt(s.stmt, s.query, args...)
+	return s.ExecCtx(context.Background(), args...)
+}
+
+func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) {
+	return execStmt(ctx, s.stmt, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRow(v interface{}, args ...interface{}) error {
 func (s statement) QueryRow(v interface{}, args ...interface{}) error {
-	return queryStmt(s.stmt, func(rows *sql.Rows) error {
+	return s.QueryRowCtx(context.Background(), v, args...)
+}
+
+func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error {
+	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, true)
 		return unmarshalRow(v, rows, true)
 	}, s.query, args...)
 	}, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
 func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
-	return queryStmt(s.stmt, func(rows *sql.Rows) error {
+	return s.QueryRowPartialCtx(context.Background(), v, args...)
+}
+
+func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
+	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, false)
 		return unmarshalRow(v, rows, false)
 	}, s.query, args...)
 	}, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRows(v interface{}, args ...interface{}) error {
 func (s statement) QueryRows(v interface{}, args ...interface{}) error {
-	return queryStmt(s.stmt, func(rows *sql.Rows) error {
+	return s.QueryRowsCtx(context.Background(), v, args...)
+}
+
+func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error {
+	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, true)
 		return unmarshalRows(v, rows, true)
 	}, s.query, args...)
 	}, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
 func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
-	return queryStmt(s.stmt, func(rows *sql.Rows) error {
+	return s.QueryRowsPartialCtx(context.Background(), v, args...)
+}
+
+func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
+	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, false)
 		return unmarshalRows(v, rows, false)
 	}, s.query, args...)
 	}, s.query, args...)
 }
 }

+ 23 - 20
core/stores/sqlx/stmt.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"time"
 	"time"
 
 
@@ -18,64 +19,65 @@ func SetSlowThreshold(threshold time.Duration) {
 	slowThreshold.Set(threshold)
 	slowThreshold.Set(threshold)
 }
 }
 
 
-func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
+func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
 	stmt, err := format(q, args...)
 	stmt, err := format(q, args...)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	startTime := timex.Now()
 	startTime := timex.Now()
-	result, err := conn.Exec(q, args...)
+	result, err := conn.ExecContext(ctx, q, args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)
 	if duration > slowThreshold.Load() {
 	if duration > slowThreshold.Load() {
-		logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
 	} else {
 	} else {
-		logx.WithDuration(duration).Infof("sql exec: %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt)
 	}
 	}
 	if err != nil {
 	if err != nil {
-		logSqlError(stmt, err)
+		logSqlError(ctx, stmt, err)
 	}
 	}
 
 
 	return result, err
 	return result, err
 }
 }
 
 
-func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
+func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
 	stmt, err := format(q, args...)
 	stmt, err := format(q, args...)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	startTime := timex.Now()
 	startTime := timex.Now()
-	result, err := conn.Exec(args...)
+	result, err := conn.ExecContext(ctx, args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)
 	if duration > slowThreshold.Load() {
 	if duration > slowThreshold.Load() {
-		logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
 	} else {
 	} else {
-		logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt)
 	}
 	}
 	if err != nil {
 	if err != nil {
-		logSqlError(stmt, err)
+		logSqlError(ctx, stmt, err)
 	}
 	}
 
 
 	return result, err
 	return result, err
 }
 }
 
 
-func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
+func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
+	q string, args ...interface{}) error {
 	stmt, err := format(q, args...)
 	stmt, err := format(q, args...)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
 	startTime := timex.Now()
 	startTime := timex.Now()
-	rows, err := conn.Query(q, args...)
+	rows, err := conn.QueryContext(ctx, q, args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)
 	if duration > slowThreshold.Load() {
 	if duration > slowThreshold.Load() {
-		logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
 	} else {
 	} else {
-		logx.WithDuration(duration).Infof("sql query: %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt)
 	}
 	}
 	if err != nil {
 	if err != nil {
-		logSqlError(stmt, err)
+		logSqlError(ctx, stmt, err)
 		return err
 		return err
 	}
 	}
 	defer rows.Close()
 	defer rows.Close()
@@ -83,22 +85,23 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
 	return scanner(rows)
 	return scanner(rows)
 }
 }
 
 
-func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
+func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
+	q string, args ...interface{}) error {
 	stmt, err := format(q, args...)
 	stmt, err := format(q, args...)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
 	startTime := timex.Now()
 	startTime := timex.Now()
-	rows, err := conn.Query(args...)
+	rows, err := conn.QueryContext(ctx, args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)
 	if duration > slowThreshold.Load() {
 	if duration > slowThreshold.Load() {
-		logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
 	} else {
 	} else {
-		logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
+		logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt)
 	}
 	}
 	if err != nil {
 	if err != nil {
-		logSqlError(stmt, err)
+		logSqlError(ctx, stmt, err)
 		return err
 		return err
 	}
 	}
 	defer rows.Close()
 	defer rows.Close()

+ 21 - 4
core/stores/sqlx/stmt_test.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"errors"
 	"errors"
 	"testing"
 	"testing"
@@ -57,7 +58,7 @@ func TestStmt_exec(t *testing.T) {
 		test := test
 		test := test
 		fns := []func(args ...interface{}) (sql.Result, error){
 		fns := []func(args ...interface{}) (sql.Result, error){
 			func(args ...interface{}) (sql.Result, error) {
 			func(args ...interface{}) (sql.Result, error) {
-				return exec(&mockedSessionConn{
+				return exec(context.Background(), &mockedSessionConn{
 					lastInsertId: test.lastInsertId,
 					lastInsertId: test.lastInsertId,
 					rowsAffected: test.rowsAffected,
 					rowsAffected: test.rowsAffected,
 					err:          test.err,
 					err:          test.err,
@@ -65,7 +66,7 @@ func TestStmt_exec(t *testing.T) {
 				}, test.query, args...)
 				}, test.query, args...)
 			},
 			},
 			func(args ...interface{}) (sql.Result, error) {
 			func(args ...interface{}) (sql.Result, error) {
-				return execStmt(&mockedStmtConn{
+				return execStmt(context.Background(), &mockedStmtConn{
 					lastInsertId: test.lastInsertId,
 					lastInsertId: test.lastInsertId,
 					rowsAffected: test.rowsAffected,
 					rowsAffected: test.rowsAffected,
 					err:          test.err,
 					err:          test.err,
@@ -137,7 +138,7 @@ func TestStmt_query(t *testing.T) {
 		test := test
 		test := test
 		fns := []func(args ...interface{}) error{
 		fns := []func(args ...interface{}) error{
 			func(args ...interface{}) error {
 			func(args ...interface{}) error {
-				return query(&mockedSessionConn{
+				return query(context.Background(), &mockedSessionConn{
 					err:   test.err,
 					err:   test.err,
 					delay: test.delay,
 					delay: test.delay,
 				}, func(rows *sql.Rows) error {
 				}, func(rows *sql.Rows) error {
@@ -145,7 +146,7 @@ func TestStmt_query(t *testing.T) {
 				}, test.query, args...)
 				}, test.query, args...)
 			},
 			},
 			func(args ...interface{}) error {
 			func(args ...interface{}) error {
-				return queryStmt(&mockedStmtConn{
+				return queryStmt(context.Background(), &mockedStmtConn{
 					err:   test.err,
 					err:   test.err,
 					delay: test.delay,
 					delay: test.delay,
 				}, func(rows *sql.Rows) error {
 				}, func(rows *sql.Rows) error {
@@ -185,6 +186,10 @@ type mockedSessionConn struct {
 }
 }
 
 
 func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
 func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
+	return m.ExecContext(context.Background(), query, args...)
+}
+
+func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 	if m.delay {
 	if m.delay {
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 	}
 	}
@@ -195,6 +200,10 @@ func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result,
 }
 }
 
 
 func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
 func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
+	return m.QueryContext(context.Background(), query, args...)
+}
+
+func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 	if m.delay {
 	if m.delay {
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 	}
 	}
@@ -214,6 +223,10 @@ type mockedStmtConn struct {
 }
 }
 
 
 func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
 func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
+	return m.ExecContext(context.Background(), args...)
+}
+
+func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) {
 	if m.delay {
 	if m.delay {
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 	}
 	}
@@ -224,6 +237,10 @@ func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
 }
 }
 
 
 func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
 func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
+	return m.QueryContext(context.Background(), args...)
+}
+
+func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) {
 	if m.delay {
 	if m.delay {
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 		time.Sleep(defaultSlowThreshold + time.Millisecond)
 	}
 	}

+ 41 - 12
core/stores/sqlx/tx.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"fmt"
 	"fmt"
 )
 )
@@ -26,11 +27,19 @@ func NewSessionFromTx(tx *sql.Tx) Session {
 }
 }
 
 
 func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
 func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
-	return exec(t.Tx, q, args...)
+	return t.ExecCtx(context.Background(), q, args...)
+}
+
+func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) {
+	return exec(ctx, t.Tx, q, args...)
 }
 }
 
 
 func (t txSession) Prepare(q string) (StmtSession, error) {
 func (t txSession) Prepare(q string) (StmtSession, error) {
-	stmt, err := t.Tx.Prepare(q)
+	return t.PrepareCtx(context.Background(), q)
+}
+
+func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) {
+	stmt, err := t.Tx.PrepareContext(ctx, q)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -42,25 +51,43 @@ func (t txSession) Prepare(q string) (StmtSession, error) {
 }
 }
 
 
 func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
 func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
-	return query(t.Tx, func(rows *sql.Rows) error {
+	return t.QueryRowCtx(context.Background(), v, q, args...)
+}
+
+func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
+	return query(ctx, t.Tx, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, true)
 		return unmarshalRow(v, rows, true)
 	}, q, args...)
 	}, q, args...)
 }
 }
 
 
 func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
 func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
-	return query(t.Tx, func(rows *sql.Rows) error {
+	return t.QueryRowPartialCtx(context.Background(), v, q, args...)
+}
+
+func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
+	args ...interface{}) error {
+	return query(ctx, t.Tx, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, false)
 		return unmarshalRow(v, rows, false)
 	}, q, args...)
 	}, q, args...)
 }
 }
 
 
 func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
 func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
-	return query(t.Tx, func(rows *sql.Rows) error {
+	return t.QueryRowsCtx(context.Background(), v, q, args...)
+}
+
+func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
+	return query(ctx, t.Tx, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, true)
 		return unmarshalRows(v, rows, true)
 	}, q, args...)
 	}, q, args...)
 }
 }
 
 
 func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
 func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
-	return query(t.Tx, func(rows *sql.Rows) error {
+	return t.QueryRowsPartialCtx(context.Background(), v, q, args...)
+}
+
+func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
+	args ...interface{}) error {
+	return query(ctx, t.Tx, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, false)
 		return unmarshalRows(v, rows, false)
 	}, q, args...)
 	}, q, args...)
 }
 }
@@ -76,17 +103,19 @@ func begin(db *sql.DB) (trans, error) {
 	}, nil
 	}, nil
 }
 }
 
 
-func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
+func transact(ctx context.Context, db *commonSqlConn, b beginnable,
+	fn func(context.Context, Session) error) (err error) {
 	conn, err := db.connProv()
 	conn, err := db.connProv()
 	if err != nil {
 	if err != nil {
 		db.onError(err)
 		db.onError(err)
 		return err
 		return err
 	}
 	}
 
 
-	return transactOnConn(conn, b, fn)
+	return transactOnConn(ctx, conn, b, fn)
 }
 }
 
 
-func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
+func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
+	fn func(context.Context, Session) error) (err error) {
 	var tx trans
 	var tx trans
 	tx, err = b(conn)
 	tx, err = b(conn)
 	if err != nil {
 	if err != nil {
@@ -96,18 +125,18 @@ func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err err
 	defer func() {
 	defer func() {
 		if p := recover(); p != nil {
 		if p := recover(); p != nil {
 			if e := tx.Rollback(); e != nil {
 			if e := tx.Rollback(); e != nil {
-				err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
+				err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
 			} else {
 			} else {
 				err = fmt.Errorf("recoveer from %#v", p)
 				err = fmt.Errorf("recoveer from %#v", p)
 			}
 			}
 		} else if err != nil {
 		} else if err != nil {
 			if e := tx.Rollback(); e != nil {
 			if e := tx.Rollback(); e != nil {
-				err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
+				err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
 			}
 			}
 		} else {
 		} else {
 			err = tx.Commit()
 			err = tx.Commit()
 		}
 		}
 	}()
 	}()
 
 
-	return fn(tx)
+	return fn(ctx, tx)
 }
 }

+ 33 - 6
core/stores/sqlx/tx_test.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"database/sql"
 	"database/sql"
 	"errors"
 	"errors"
 	"testing"
 	"testing"
@@ -26,26 +27,50 @@ func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
 	return nil, nil
 	return nil, nil
 }
 }
 
 
+func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+	return nil, nil
+}
+
 func (mt *mockTx) Prepare(query string) (StmtSession, error) {
 func (mt *mockTx) Prepare(query string) (StmtSession, error) {
 	return nil, nil
 	return nil, nil
 }
 }
 
 
+func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
+	return nil, nil
+}
+
 func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
 func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
 	return nil
 	return nil
 }
 }
 
 
+func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	return nil
+}
+
 func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
 func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
 	return nil
 	return nil
 }
 }
 
 
+func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	return nil
+}
+
 func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
 func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
 	return nil
 	return nil
 }
 }
 
 
+func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	return nil
+}
+
 func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
 func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
 	return nil
 	return nil
 }
 }
 
 
+func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
+	return nil
+}
+
 func (mt *mockTx) Rollback() error {
 func (mt *mockTx) Rollback() error {
 	mt.status |= mockRollback
 	mt.status |= mockRollback
 	return nil
 	return nil
@@ -59,18 +84,20 @@ func beginMock(mock *mockTx) beginnable {
 
 
 func TestTransactCommit(t *testing.T) {
 func TestTransactCommit(t *testing.T) {
 	mock := &mockTx{}
 	mock := &mockTx{}
-	err := transactOnConn(nil, beginMock(mock), func(Session) error {
-		return nil
-	})
+	err := transactOnConn(context.Background(), nil, beginMock(mock),
+		func(context.Context, Session) error {
+			return nil
+		})
 	assert.Equal(t, mockCommit, mock.status)
 	assert.Equal(t, mockCommit, mock.status)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 }
 }
 
 
 func TestTransactRollback(t *testing.T) {
 func TestTransactRollback(t *testing.T) {
 	mock := &mockTx{}
 	mock := &mockTx{}
-	err := transactOnConn(nil, beginMock(mock), func(Session) error {
-		return errors.New("rollback")
-	})
+	err := transactOnConn(context.Background(), nil, beginMock(mock),
+		func(context.Context, Session) error {
+			return errors.New("rollback")
+		})
 	assert.Equal(t, mockRollback, mock.status)
 	assert.Equal(t, mockRollback, mock.status)
 	assert.NotNil(t, err)
 	assert.NotNil(t, err)
 }
 }

+ 3 - 2
core/stores/sqlx/utils.go

@@ -1,6 +1,7 @@
 package sqlx
 package sqlx
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -109,9 +110,9 @@ func logInstanceError(datasource string, err error) {
 	logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
 	logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
 }
 }
 
 
-func logSqlError(stmt string, err error) {
+func logSqlError(ctx context.Context, stmt string, err error) {
 	if err != nil && err != ErrNotFound {
 	if err != nil && err != ErrNotFound {
-		logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
+		logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
 	}
 	}
 }
 }
 
 

+ 2 - 2
tools/goctl/api/docgen/gen.go

@@ -12,7 +12,7 @@ import (
 	"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
 	"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
 )
 )
 
 
-// DocCommand generate markdown doc file
+// DocCommand generate Markdown doc file
 func DocCommand(c *cli.Context) error {
 func DocCommand(c *cli.Context) error {
 	dir := c.String("dir")
 	dir := c.String("dir")
 	if len(dir) == 0 {
 	if len(dir) == 0 {
@@ -45,7 +45,7 @@ func DocCommand(c *cli.Context) error {
 	for _, p := range files {
 	for _, p := range files {
 		api, err := parser.Parse(p)
 		api, err := parser.Parse(p)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("parse file: %s, err: %s", p, err.Error())
+			return fmt.Errorf("parse file: %s, err: %w", p, err)
 		}
 		}
 
 
 		api.Service = api.Service.JoinPrefix()
 		api.Service = api.Service.JoinPrefix()

+ 2 - 2
tools/goctl/migrate/migrate.go

@@ -164,12 +164,12 @@ func writeFile(pkgs []*ast.Package, verbose bool) error {
 			w := bytes.NewBuffer(nil)
 			w := bytes.NewBuffer(nil)
 			err := format.Node(w, fset, file)
 			err := format.Node(w, fset, file)
 			if err != nil {
 			if err != nil {
-				return fmt.Errorf("[rewriteImport] format file %s error: %+v", filename, err)
+				return fmt.Errorf("[rewriteImport] format file %s error: %w", filename, err)
 			}
 			}
 
 
 			err = ioutil.WriteFile(filename, w.Bytes(), os.ModePerm)
 			err = ioutil.WriteFile(filename, w.Bytes(), os.ModePerm)
 			if err != nil {
 			if err != nil {
-				return fmt.Errorf("[rewriteImport] write file %s error: %+v", filename, err)
+				return fmt.Errorf("[rewriteImport] write file %s error: %w", filename, err)
 			}
 			}
 			if verbose {
 			if verbose {
 				console.Success("[OK] migrated %q successfully", filepath.Base(filename))
 				console.Success("[OK] migrated %q successfully", filepath.Base(filename))