Эх сурвалжийг харах

support postgresql (#583)

support postgresql
Kevin Wan 4 жил өмнө
parent
commit
bd623aaac3

+ 1 - 0
core/stores/sqlx/bulkinserter.go

@@ -24,6 +24,7 @@ type (
 	ResultHandler func(sql.Result, error)
 
 	// A BulkInserter is used to batch insert records.
+	// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
 	BulkInserter struct {
 		executor *executors.PeriodicalExecutor
 		inserter *dbInserter

+ 3 - 11
core/stores/sqlx/stmt.go

@@ -12,14 +12,10 @@ import (
 const slowThreshold = time.Millisecond * 500
 
 func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
-	stmt, err := format(q, args...)
-	if err != nil {
-		return nil, err
-	}
-
 	startTime := timex.Now()
 	result, err := conn.Exec(q, args...)
 	duration := timex.Since(startTime)
+	stmt := formatForPrint(q, args)
 	if duration > slowThreshold {
 		logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
 	} else {
@@ -33,10 +29,10 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
 }
 
 func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
-	stmt := fmt.Sprint(args...)
 	startTime := timex.Now()
 	result, err := conn.Exec(args...)
 	duration := timex.Since(startTime)
+	stmt := fmt.Sprint(args...)
 	if duration > slowThreshold {
 		logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
 	} else {
@@ -50,14 +46,10 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
 }
 
 func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
-	stmt, err := format(q, args...)
-	if err != nil {
-		return err
-	}
-
 	startTime := timex.Now()
 	rows, err := conn.Query(q, args...)
 	duration := timex.Since(startTime)
+	stmt := fmt.Sprint(args...)
 	if duration > slowThreshold {
 		logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
 	} else {

+ 9 - 33
core/stores/sqlx/stmt_test.go

@@ -16,7 +16,6 @@ func TestStmt_exec(t *testing.T) {
 		name         string
 		args         []interface{}
 		delay        bool
-		formatError  bool
 		hasError     bool
 		err          error
 		lastInsertId int64
@@ -28,12 +27,6 @@ func TestStmt_exec(t *testing.T) {
 			lastInsertId: 1,
 			rowsAffected: 2,
 		},
-		{
-			name:        "wrong format",
-			args:        []interface{}{1, 2},
-			formatError: true,
-			hasError:    true,
-		},
 		{
 			name:     "exec error",
 			args:     []interface{}{1},
@@ -70,18 +63,13 @@ func TestStmt_exec(t *testing.T) {
 			},
 		}
 
-		for i, fn := range fns {
-			i := i
+		for _, fn := range fns {
 			fn := fn
 			t.Run(test.name, func(t *testing.T) {
 				t.Parallel()
 
 				res, err := fn(test.args...)
-				if i == 0 && test.formatError {
-					assert.NotNil(t, err)
-					return
-				}
-				if !test.formatError && test.hasError {
+				if test.hasError {
 					assert.NotNil(t, err)
 					return
 				}
@@ -100,23 +88,16 @@ func TestStmt_exec(t *testing.T) {
 
 func TestStmt_query(t *testing.T) {
 	tests := []struct {
-		name        string
-		args        []interface{}
-		delay       bool
-		formatError bool
-		hasError    bool
-		err         error
+		name     string
+		args     []interface{}
+		delay    bool
+		hasError bool
+		err      error
 	}{
 		{
 			name: "normal",
 			args: []interface{}{1},
 		},
-		{
-			name:        "wrong format",
-			args:        []interface{}{1, 2},
-			formatError: true,
-			hasError:    true,
-		},
 		{
 			name:     "query error",
 			args:     []interface{}{1},
@@ -151,18 +132,13 @@ func TestStmt_query(t *testing.T) {
 			},
 		}
 
-		for i, fn := range fns {
-			i := i
+		for _, fn := range fns {
 			fn := fn
 			t.Run(test.name, func(t *testing.T) {
 				t.Parallel()
 
 				err := fn(test.args...)
-				if i == 0 && test.formatError {
-					assert.NotNil(t, err)
-					return
-				}
-				if !test.formatError && test.hasError {
+				if test.hasError {
 					assert.NotNil(t, err)
 					return
 				}

+ 18 - 0
core/stores/sqlx/utils.go

@@ -45,6 +45,24 @@ func escape(input string) string {
 	return b.String()
 }
 
+func formatForPrint(query string, args ...interface{}) string {
+	if len(args) == 0 {
+		return query
+	}
+
+	var vals []string
+	for _, arg := range args {
+		vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
+	}
+
+	var b strings.Builder
+	b.WriteByte('[')
+	b.WriteString(strings.Join(vals, ", "))
+	b.WriteByte(']')
+
+	return strings.Join([]string{query, b.String()}, " ")
+}
+
 func format(query string, args ...interface{}) (string, error) {
 	numArgs := len(args)
 	if numArgs == 0 {

+ 28 - 0
core/stores/sqlx/utils_test.go

@@ -28,3 +28,31 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
 	datasource = desensitize(datasource)
 	assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
 }
+
+func TestFormatForPrint(t *testing.T) {
+	tests := []struct {
+		name   string
+		query  string
+		args   []interface{}
+		expect string
+	}{
+		{
+			name:   "no args",
+			query:  "select user, name from table where id=?",
+			expect: `select user, name from table where id=?`,
+		},
+		{
+			name:   "one arg",
+			query:  "select user, name from table where id=?",
+			args:   []interface{}{"kevin"},
+			expect: `select user, name from table where id=? ["kevin"]`,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			actual := formatForPrint(test.query, test.args...)
+			assert.Equal(t, test.expect, actual)
+		})
+	}
+}