Explorar o código

fix #1806 (#1833)

* fix #1806

* chore: refine error text
Kevin Wan %!s(int64=3) %!d(string=hai) anos
pai
achega
5bcee4cf7c

+ 1 - 1
core/logx/logs.go

@@ -275,7 +275,7 @@ func Infov(v interface{}) {
 	infoAnySync(v)
 }
 
-// Must checks if err is nil, otherwise logs the err and exits.
+// Must checks if err is nil, otherwise logs the error and exits.
 func Must(err error) {
 	if err != nil {
 		msg := formatWithCaller(err.Error(), 3)

+ 27 - 0
core/stores/sqlx/utils.go

@@ -2,6 +2,7 @@ package sqlx
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"strconv"
 	"strings"
@@ -10,6 +11,8 @@ import (
 	"github.com/zeromicro/go-zero/core/mapping"
 )
 
+var errUnbalancedEscape = errors.New("no char after escape char")
+
 func desensitize(datasource string) string {
 	// remove account
 	pos := strings.LastIndex(datasource, "@")
@@ -95,6 +98,30 @@ func format(query string, args ...interface{}) (string, error) {
 				writeValue(&b, args[index])
 				i = j - 1
 			}
+		case '\'', '"', '`':
+			b.WriteByte(ch)
+			for j := i + 1; j < bytes; j++ {
+				cur := query[j]
+				b.WriteByte(cur)
+
+				switch cur {
+				case '\\':
+					j++
+					if j >= bytes {
+						return "", errUnbalancedEscape
+					}
+
+					b.WriteByte(query[j])
+				case '\'', '"', '`':
+					if cur == ch {
+						i = j
+						goto end
+					}
+				}
+			}
+
+		end:
+			break
 		default:
 			b.WriteByte(ch)
 		}

+ 25 - 0
core/stores/sqlx/utils_test.go

@@ -97,6 +97,30 @@ func TestFormat(t *testing.T) {
 			args:   []interface{}{"133", false},
 			hasErr: true,
 		},
+		{
+			name:   "select with date",
+			query:  "select * from user where date='2006-01-02 15:04:05' and name=:1",
+			args:   []interface{}{"foo"},
+			expect: "select * from user where date='2006-01-02 15:04:05' and name='foo'",
+		},
+		{
+			name:   "select with date and escape",
+			query:  `select * from user where date=' 2006-01-02 15:04:05 \'' and name=:1`,
+			args:   []interface{}{"foo"},
+			expect: `select * from user where date=' 2006-01-02 15:04:05 \'' and name='foo'`,
+		},
+		{
+			name:   "select with date and bad arg",
+			query:  `select * from user where date='2006-01-02 15:04:05 \'' and name=:a`,
+			args:   []interface{}{"foo"},
+			hasErr: true,
+		},
+		{
+			name:   "select with date and escape error",
+			query:  `select * from user where date='2006-01-02 15:04:05 \`,
+			args:   []interface{}{"foo"},
+			hasErr: true,
+		},
 	}
 
 	for _, test := range tests {
@@ -108,6 +132,7 @@ func TestFormat(t *testing.T) {
 			if test.hasErr {
 				assert.NotNil(t, err)
 			} else {
+				assert.Nil(t, err)
 				assert.Equal(t, test.expect, actual)
 			}
 		})

+ 4 - 7
tools/goctl/model/sql/parser/parser.go

@@ -69,7 +69,6 @@ func Parse(filename, database string) ([]*Table, error) {
 	}
 
 	nameOriginals := parseNameOriginal(tables)
-
 	indexNameGen := func(column ...string) string {
 		return strings.Join(column, "_")
 	}
@@ -77,14 +76,12 @@ func Parse(filename, database string) ([]*Table, error) {
 	prefix := filepath.Base(filename)
 	var list []*Table
 	for indexTable, e := range tables {
-		columns := e.Columns
-
 		var (
+			primaryColumn    string
 			primaryColumnSet = collection.NewSet()
-
-			primaryColumn string
-			uniqueKeyMap  = make(map[string][]string)
-			normalKeyMap  = make(map[string][]string)
+			uniqueKeyMap     = make(map[string][]string)
+			normalKeyMap     = make(map[string][]string)
+			columns          = e.Columns
 		)
 
 		for _, column := range columns {