Bläddra i källkod

expose sql.DB to let orm operate on it (#1015)

* expose sql.DB to let orm operate on it

* add missing RawDB methods

* add NewSqlConnFromDB for cooperate with dtm
Kevin Wan 3 år sedan
förälder
incheckning
f6d9e19ecb

+ 8 - 0
core/stores/sqlc/cachedsql_test.go

@@ -600,6 +600,10 @@ func (d dummySqlConn) QueryRowsPartial(v interface{}, query string, args ...inte
 	return nil
 }
 
+func (d dummySqlConn) RawDB() (*sql.DB, error) {
+	return nil, nil
+}
+
 func (d dummySqlConn) Transact(func(session sqlx.Session) error) error {
 	return nil
 }
@@ -621,6 +625,10 @@ func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}
 	return c.dummySqlConn.QueryRows(v, query, args...)
 }
 
+func (c *trackedConn) RawDB() (*sql.DB, error) {
+	return nil, nil
+}
+
 func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
 	c.transactValue = true
 	return c.dummySqlConn.Transact(fn)

+ 4 - 0
core/stores/sqlx/bulkinserter_test.go

@@ -43,6 +43,10 @@ func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...inter
 	panic("should not called")
 }
 
+func (c *mockedConn) RawDB() (*sql.DB, error) {
+	panic("should not called")
+}
+
 func (c *mockedConn) Transact(func(session Session) error) error {
 	panic("should not called")
 }

+ 37 - 7
core/stores/sqlx/sqlconn.go

@@ -6,6 +6,9 @@ import (
 	"github.com/tal-tech/go-zero/core/breaker"
 )
 
+// datasource placeholder for logging error.
+const rawDB = "sql.DB"
+
 // ErrNotFound is an alias of sql.ErrNoRows
 var ErrNotFound = sql.ErrNoRows
 
@@ -23,6 +26,7 @@ type (
 	// SqlConn only stands for raw connections, so Transact method can be called.
 	SqlConn interface {
 		Session
+		RawDB() (*sql.DB, error)
 		Transact(func(session Session) error) error
 	}
 
@@ -43,13 +47,15 @@ type (
 	// Because CORBA doesn't support PREPARE, so we need to combine the
 	// query arguments into one string and do underlying query without arguments
 	commonSqlConn struct {
-		driverName string
 		datasource string
+		connProv   connProvider
 		beginTx    beginnable
 		brk        breaker.Breaker
 		accept     func(error) bool
 	}
 
+	connProvider func() (*sql.DB, error)
+
 	sessionConn interface {
 		Exec(query string, args ...interface{}) (sql.Result, error)
 		Query(query string, args ...interface{}) (*sql.Rows, error)
@@ -69,10 +75,30 @@ type (
 // NewSqlConn returns a SqlConn with given driver name and datasource.
 func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
 	conn := &commonSqlConn{
-		driverName: driverName,
 		datasource: datasource,
-		beginTx:    begin,
-		brk:        breaker.NewBreaker(),
+		connProv: func() (*sql.DB, error) {
+			return getSqlConn(driverName, datasource)
+		},
+		beginTx: begin,
+		brk:     breaker.NewBreaker(),
+	}
+	for _, opt := range opts {
+		opt(conn)
+	}
+
+	return conn
+}
+
+// NewSqlConnFromDB returns a SqlConn with the given sql.DB.
+// Use it with caution, it's provided for other ORM to interact with.
+func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
+	conn := &commonSqlConn{
+		datasource: rawDB,
+		connProv: func() (*sql.DB, error) {
+			return db, nil
+		},
+		beginTx: begin,
+		brk:     breaker.NewBreaker(),
 	}
 	for _, opt := range opts {
 		opt(conn)
@@ -84,7 +110,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
 func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
 	err = db.brk.DoWithAcceptable(func() error {
 		var conn *sql.DB
-		conn, err = getSqlConn(db.driverName, db.datasource)
+		conn, err = db.connProv()
 		if err != nil {
 			logInstanceError(db.datasource, err)
 			return err
@@ -100,7 +126,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
 func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
 	err = db.brk.DoWithAcceptable(func() error {
 		var conn *sql.DB
-		conn, err = getSqlConn(db.driverName, db.datasource)
+		conn, err = db.connProv()
 		if err != nil {
 			logInstanceError(db.datasource, err)
 			return err
@@ -145,6 +171,10 @@ func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...inter
 	}, q, args...)
 }
 
+func (db *commonSqlConn) RawDB() (*sql.DB, error) {
+	return db.connProv()
+}
+
 func (db *commonSqlConn) Transact(fn func(Session) error) error {
 	return db.brk.DoWithAcceptable(func() error {
 		return transact(db, db.beginTx, fn)
@@ -163,7 +193,7 @@ func (db *commonSqlConn) acceptable(err error) bool {
 func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
 	var qerr error
 	return db.brk.DoWithAcceptable(func() error {
-		conn, err := getSqlConn(db.driverName, db.datasource)
+		conn, err := db.connProv()
 		if err != nil {
 			logInstanceError(db.datasource, err)
 			return err

+ 5 - 2
core/stores/sqlx/sqlconn_test.go

@@ -21,12 +21,15 @@ func TestSqlConn(t *testing.T) {
 	mock.ExpectExec("any")
 	mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
 	conn := NewMysql(mockedDatasource)
+	db, err := conn.RawDB()
+	assert.Nil(t, err)
+	rawConn := NewSqlConnFromDB(db, withMysqlAcceptable())
 	badConn := NewMysql("badsql")
-	_, err := conn.Exec("any", "value")
+	_, err = conn.Exec("any", "value")
 	assert.NotNil(t, err)
 	_, err = badConn.Exec("any", "value")
 	assert.NotNil(t, err)
-	_, err = conn.Prepare("any")
+	_, err = rawConn.Prepare("any")
 	assert.NotNil(t, err)
 	_, err = badConn.Prepare("any")
 	assert.NotNil(t, err)

+ 1 - 1
core/stores/sqlx/tx.go

@@ -71,7 +71,7 @@ func begin(db *sql.DB) (trans, error) {
 }
 
 func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
-	conn, err := getSqlConn(db.driverName, db.datasource)
+	conn, err := db.connProv()
 	if err != nil {
 		logInstanceError(db.datasource, err)
 		return err

+ 6 - 0
tools/goctl/model/sql/test/sqlconn.go

@@ -13,6 +13,7 @@ type (
 	MockConn struct {
 		db *sql.DB
 	}
+
 	statement struct {
 		stmt *sql.Stmt
 	}
@@ -62,6 +63,11 @@ func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interfac
 	}, q, args...)
 }
 
+// RawDB returns the underlying sql.DB.
+func (conn *MockConn) RawDB() (*sql.DB, error) {
+	return conn.db, nil
+}
+
 // Transact is the implemention of sqlx.SqlConn, nothing to do
 func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
 	return nil