浏览代码

print entire sql statements in logx if necessary (#704)

Kevin Wan 4 年之前
父节点
当前提交
aaa39e17a3
共有 5 个文件被更改,包括 153 次插入67 次删除
  1. 9 7
      core/stores/sqlx/sqlconn.go
  2. 22 7
      core/stores/sqlx/stmt.go
  3. 29 7
      core/stores/sqlx/stmt_test.go
  4. 50 36
      core/stores/sqlx/utils.go
  5. 43 10
      core/stores/sqlx/utils_test.go

+ 9 - 7
core/stores/sqlx/sqlconn.go

@@ -56,7 +56,8 @@ type (
 	}
 	}
 
 
 	statement struct {
 	statement struct {
-		stmt *sql.Stmt
+		query string
+		stmt  *sql.Stmt
 	}
 	}
 
 
 	stmtConn interface {
 	stmtConn interface {
@@ -111,7 +112,8 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
 		}
 		}
 
 
 		stmt = statement{
 		stmt = statement{
-			stmt: st,
+			query: query,
+			stmt:  st,
 		}
 		}
 		return nil
 		return nil
 	}, db.acceptable)
 	}, db.acceptable)
@@ -181,29 +183,29 @@ func (s statement) Close() error {
 }
 }
 
 
 func (s statement) Exec(args ...interface{}) (sql.Result, error) {
 func (s statement) Exec(args ...interface{}) (sql.Result, error) {
-	return execStmt(s.stmt, args...)
+	return execStmt(s.stmt, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRow(v interface{}, args ...interface{}) error {
 func (s statement) QueryRow(v interface{}, args ...interface{}) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, true)
 		return unmarshalRow(v, rows, true)
-	}, args...)
+	}, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
 func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRow(v, rows, false)
 		return unmarshalRow(v, rows, false)
-	}, args...)
+	}, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRows(v interface{}, args ...interface{}) error {
 func (s statement) QueryRows(v interface{}, args ...interface{}) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, true)
 		return unmarshalRows(v, rows, true)
-	}, args...)
+	}, s.query, args...)
 }
 }
 
 
 func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
 func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 	return queryStmt(s.stmt, func(rows *sql.Rows) error {
 		return unmarshalRows(v, rows, false)
 		return unmarshalRows(v, rows, false)
-	}, args...)
+	}, s.query, args...)
 }
 }

+ 22 - 7
core/stores/sqlx/stmt.go

@@ -2,7 +2,6 @@ package sqlx
 
 
 import (
 import (
 	"database/sql"
 	"database/sql"
-	"fmt"
 	"time"
 	"time"
 
 
 	"github.com/tal-tech/go-zero/core/logx"
 	"github.com/tal-tech/go-zero/core/logx"
@@ -12,10 +11,14 @@ import (
 const slowThreshold = time.Millisecond * 500
 const slowThreshold = time.Millisecond * 500
 
 
 func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
 func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
+	stmt, err := format(q, args...)
+	if err != nil {
+		return nil, err
+	}
+
 	startTime := timex.Now()
 	startTime := timex.Now()
 	result, err := conn.Exec(q, args...)
 	result, err := conn.Exec(q, args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)
-	stmt := formatForPrint(q, args)
 	if duration > slowThreshold {
 	if duration > slowThreshold {
 		logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
 		logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
 	} else {
 	} else {
@@ -28,11 +31,15 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
 	return result, err
 	return result, err
 }
 }
 
 
-func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
+func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
+	stmt, err := format(q, args...)
+	if err != nil {
+		return nil, err
+	}
+
 	startTime := timex.Now()
 	startTime := timex.Now()
 	result, err := conn.Exec(args...)
 	result, err := conn.Exec(args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)
-	stmt := fmt.Sprint(args...)
 	if duration > slowThreshold {
 	if duration > slowThreshold {
 		logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
 		logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
 	} else {
 	} else {
@@ -46,10 +53,14 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
 }
 }
 
 
 func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
 func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
+	stmt, err := format(q, args...)
+	if err != nil {
+		return err
+	}
+
 	startTime := timex.Now()
 	startTime := timex.Now()
 	rows, err := conn.Query(q, args...)
 	rows, err := conn.Query(q, args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)
-	stmt := fmt.Sprint(args...)
 	if duration > slowThreshold {
 	if duration > slowThreshold {
 		logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
 		logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
 	} else {
 	} else {
@@ -64,8 +75,12 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
 	return scanner(rows)
 	return scanner(rows)
 }
 }
 
 
-func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error {
-	stmt := fmt.Sprint(args...)
+func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
+	stmt, err := format(q, args...)
+	if err != nil {
+		return err
+	}
+
 	startTime := timex.Now()
 	startTime := timex.Now()
 	rows, err := conn.Query(args...)
 	rows, err := conn.Query(args...)
 	duration := timex.Since(startTime)
 	duration := timex.Since(startTime)

+ 29 - 7
core/stores/sqlx/stmt_test.go

@@ -14,6 +14,7 @@ var errMockedPlaceholder = errors.New("placeholder")
 func TestStmt_exec(t *testing.T) {
 func TestStmt_exec(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name         string
 		name         string
+		query        string
 		args         []interface{}
 		args         []interface{}
 		delay        bool
 		delay        bool
 		hasError     bool
 		hasError     bool
@@ -23,18 +24,28 @@ func TestStmt_exec(t *testing.T) {
 	}{
 	}{
 		{
 		{
 			name:         "normal",
 			name:         "normal",
+			query:        "select user from users where id=?",
 			args:         []interface{}{1},
 			args:         []interface{}{1},
 			lastInsertId: 1,
 			lastInsertId: 1,
 			rowsAffected: 2,
 			rowsAffected: 2,
 		},
 		},
 		{
 		{
 			name:     "exec error",
 			name:     "exec error",
+			query:    "select user from users where id=?",
+			args:     []interface{}{1},
+			hasError: true,
+			err:      errors.New("exec"),
+		},
+		{
+			name:     "exec more args error",
+			query:    "select user from users where id=? and name=?",
 			args:     []interface{}{1},
 			args:     []interface{}{1},
 			hasError: true,
 			hasError: true,
 			err:      errors.New("exec"),
 			err:      errors.New("exec"),
 		},
 		},
 		{
 		{
 			name:         "slowcall",
 			name:         "slowcall",
+			query:        "select user from users where id=?",
 			args:         []interface{}{1},
 			args:         []interface{}{1},
 			delay:        true,
 			delay:        true,
 			lastInsertId: 1,
 			lastInsertId: 1,
@@ -51,7 +62,7 @@ func TestStmt_exec(t *testing.T) {
 					rowsAffected: test.rowsAffected,
 					rowsAffected: test.rowsAffected,
 					err:          test.err,
 					err:          test.err,
 					delay:        test.delay,
 					delay:        test.delay,
-				}, "select user from users where id=?", args...)
+				}, test.query, args...)
 			},
 			},
 			func(args ...interface{}) (sql.Result, error) {
 			func(args ...interface{}) (sql.Result, error) {
 				return execStmt(&mockedStmtConn{
 				return execStmt(&mockedStmtConn{
@@ -59,7 +70,7 @@ func TestStmt_exec(t *testing.T) {
 					rowsAffected: test.rowsAffected,
 					rowsAffected: test.rowsAffected,
 					err:          test.err,
 					err:          test.err,
 					delay:        test.delay,
 					delay:        test.delay,
-				}, args...)
+				}, test.query, args...)
 			},
 			},
 		}
 		}
 
 
@@ -89,23 +100,34 @@ func TestStmt_exec(t *testing.T) {
 func TestStmt_query(t *testing.T) {
 func TestStmt_query(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name     string
 		name     string
+		query    string
 		args     []interface{}
 		args     []interface{}
 		delay    bool
 		delay    bool
 		hasError bool
 		hasError bool
 		err      error
 		err      error
 	}{
 	}{
 		{
 		{
-			name: "normal",
-			args: []interface{}{1},
+			name:  "normal",
+			query: "select user from users where id=?",
+			args:  []interface{}{1},
 		},
 		},
 		{
 		{
 			name:     "query error",
 			name:     "query error",
+			query:    "select user from users where id=?",
+			args:     []interface{}{1},
+			hasError: true,
+			err:      errors.New("exec"),
+		},
+		{
+			name:     "query more args error",
+			query:    "select user from users where id=? and name=?",
 			args:     []interface{}{1},
 			args:     []interface{}{1},
 			hasError: true,
 			hasError: true,
 			err:      errors.New("exec"),
 			err:      errors.New("exec"),
 		},
 		},
 		{
 		{
 			name:  "slowcall",
 			name:  "slowcall",
+			query: "select user from users where id=?",
 			args:  []interface{}{1},
 			args:  []interface{}{1},
 			delay: true,
 			delay: true,
 		},
 		},
@@ -120,7 +142,7 @@ func TestStmt_query(t *testing.T) {
 					delay: test.delay,
 					delay: test.delay,
 				}, func(rows *sql.Rows) error {
 				}, func(rows *sql.Rows) error {
 					return nil
 					return nil
-				}, "select user from users where id=?", args...)
+				}, test.query, args...)
 			},
 			},
 			func(args ...interface{}) error {
 			func(args ...interface{}) error {
 				return queryStmt(&mockedStmtConn{
 				return queryStmt(&mockedStmtConn{
@@ -128,7 +150,7 @@ func TestStmt_query(t *testing.T) {
 					delay: test.delay,
 					delay: test.delay,
 				}, func(rows *sql.Rows) error {
 				}, func(rows *sql.Rows) error {
 					return nil
 					return nil
-				}, args...)
+				}, test.query, args...)
 			},
 			},
 		}
 		}
 
 
@@ -143,7 +165,7 @@ func TestStmt_query(t *testing.T) {
 					return
 					return
 				}
 				}
 
 
-				assert.Equal(t, errMockedPlaceholder, err)
+				assert.NotNil(t, err)
 			})
 			})
 		}
 		}
 	}
 	}

+ 50 - 36
core/stores/sqlx/utils.go

@@ -2,6 +2,7 @@ package sqlx
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"strconv"
 	"strings"
 	"strings"
 
 
 	"github.com/tal-tech/go-zero/core/logx"
 	"github.com/tal-tech/go-zero/core/logx"
@@ -45,24 +46,6 @@ func escape(input string) string {
 	return b.String()
 	return b.String()
 }
 }
 
 
-func formatForPrint(query string, args ...interface{}) string {
-	if len(args) == 0 {
-		return query
-	}
-
-	var vals []string
-	for _, arg := range args {
-		vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
-	}
-
-	var b strings.Builder
-	b.WriteByte('[')
-	b.WriteString(strings.Join(vals, ", "))
-	b.WriteByte(']')
-
-	return strings.Join([]string{query, b.String()}, " ")
-}
-
 func format(query string, args ...interface{}) (string, error) {
 func format(query string, args ...interface{}) (string, error) {
 	numArgs := len(args)
 	numArgs := len(args)
 	if numArgs == 0 {
 	if numArgs == 0 {
@@ -72,36 +55,50 @@ func format(query string, args ...interface{}) (string, error) {
 	var b strings.Builder
 	var b strings.Builder
 	argIndex := 0
 	argIndex := 0
 
 
-	for _, ch := range query {
-		if ch == '?' {
+	bytes := len(query)
+	for i := 0; i < bytes; i++ {
+		ch := query[i]
+		switch ch {
+		case '?':
 			if argIndex >= numArgs {
 			if argIndex >= numArgs {
 				return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
 				return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
 			}
 			}
 
 
-			arg := args[argIndex]
+			writeValue(&b, args[argIndex])
 			argIndex++
 			argIndex++
+		case '$':
+			var j int
+			for j = i + 1; j < bytes; j++ {
+				char := query[j]
+				if char < '0' || '9' < char {
+					break
+				}
+			}
+			if j > i+1 {
+				index, err := strconv.Atoi(query[i+1 : j])
+				if err != nil {
+					return "", err
+				}
 
 
-			switch v := arg.(type) {
-			case bool:
-				if v {
-					b.WriteByte('1')
-				} else {
-					b.WriteByte('0')
+				// index starts from 1 for pg
+				if index > argIndex {
+					argIndex = index
+				}
+				index--
+				if index < 0 || numArgs <= index {
+					return "", fmt.Errorf("error: wrong index %d in sql", index)
 				}
 				}
-			case string:
-				b.WriteByte('\'')
-				b.WriteString(escape(v))
-				b.WriteByte('\'')
-			default:
-				b.WriteString(mapping.Repr(v))
+
+				writeValue(&b, args[index])
+				i = j - 1
 			}
 			}
-		} else {
-			b.WriteRune(ch)
+		default:
+			b.WriteByte(ch)
 		}
 		}
 	}
 	}
 
 
 	if argIndex < numArgs {
 	if argIndex < numArgs {
-		return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
+		return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
 	}
 	}
 
 
 	return b.String(), nil
 	return b.String(), nil
@@ -117,3 +114,20 @@ func logSqlError(stmt string, err error) {
 		logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
 		logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
 	}
 	}
 }
 }
+
+func writeValue(buf *strings.Builder, arg interface{}) {
+	switch v := arg.(type) {
+	case bool:
+		if v {
+			buf.WriteByte('1')
+		} else {
+			buf.WriteByte('0')
+		}
+	case string:
+		buf.WriteByte('\'')
+		buf.WriteString(escape(v))
+		buf.WriteByte('\'')
+	default:
+		buf.WriteString(mapping.Repr(v))
+	}
+}

+ 43 - 10
core/stores/sqlx/utils_test.go

@@ -29,30 +29,63 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
 	assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
 	assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
 }
 }
 
 
-func TestFormatForPrint(t *testing.T) {
+func TestFormat(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name   string
 		name   string
 		query  string
 		query  string
 		args   []interface{}
 		args   []interface{}
 		expect string
 		expect string
+		hasErr bool
 	}{
 	}{
 		{
 		{
-			name:   "no args",
-			query:  "select user, name from table where id=?",
-			expect: `select user, name from table where id=?`,
+			name:   "mysql normal",
+			query:  "select name, age from users where bool=? and phone=?",
+			args:   []interface{}{true, "133"},
+			expect: "select name, age from users where bool=1 and phone='133'",
 		},
 		},
 		{
 		{
-			name:   "one arg",
-			query:  "select user, name from table where id=?",
-			args:   []interface{}{"kevin"},
-			expect: `select user, name from table where id=? ["kevin"]`,
+			name:   "mysql normal",
+			query:  "select name, age from users where bool=? and phone=?",
+			args:   []interface{}{false, "133"},
+			expect: "select name, age from users where bool=0 and phone='133'",
+		},
+		{
+			name:   "pg normal",
+			query:  "select name, age from users where bool=$1 and phone=$2",
+			args:   []interface{}{true, "133"},
+			expect: "select name, age from users where bool=1 and phone='133'",
+		},
+		{
+			name:   "pg normal reverse",
+			query:  "select name, age from users where bool=$2 and phone=$1",
+			args:   []interface{}{"133", false},
+			expect: "select name, age from users where bool=0 and phone='133'",
+		},
+		{
+			name:   "pg error not number",
+			query:  "select name, age from users where bool=$a and phone=$1",
+			args:   []interface{}{"133", false},
+			hasErr: true,
+		},
+		{
+			name:   "pg error more args",
+			query:  "select name, age from users where bool=$2 and phone=$1 and nickname=$3",
+			args:   []interface{}{"133", false},
+			hasErr: true,
 		},
 		},
 	}
 	}
 
 
 	for _, test := range tests {
 	for _, test := range tests {
+		test := test
 		t.Run(test.name, func(t *testing.T) {
 		t.Run(test.name, func(t *testing.T) {
-			actual := formatForPrint(test.query, test.args...)
-			assert.Equal(t, test.expect, actual)
+			t.Parallel()
+
+			actual, err := format(test.query, test.args...)
+			if test.hasErr {
+				assert.NotNil(t, err)
+			} else {
+				assert.Equal(t, test.expect, actual)
+			}
 		})
 		})
 	}
 	}
 }
 }