Jelajahi Sumber

fix: format error should not trigger circuit breaker in sqlx (#3437)

Kevin Wan 1 tahun lalu
induk
melakukan
ff04356704

+ 2 - 3
core/mapping/unmarshaler.go

@@ -158,7 +158,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
 
 	refValue := reflect.ValueOf(mapValue)
 	if refValue.Kind() != reflect.Slice {
-		return fmt.Errorf("%s: %v", fullName, errTypeMismatch)
+		return newTypeMismatchErrorWithHint(fullName, reflect.Slice.String(), refValue.Type().String())
 	}
 	if refValue.IsNil() {
 		return nil
@@ -180,9 +180,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
 			continue
 		}
 
+		valid = true
 		sliceFullName := fmt.Sprintf("%s[%d]", fullName, i)
 
-		valid = true
 		switch dereffedBaseKind {
 		case reflect.Struct:
 			target := reflect.New(dereffedBaseType)
@@ -319,7 +319,6 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any,
 	for _, key := range refValue.MapKeys() {
 		keythValue := refValue.MapIndex(key)
 		keythData := keythValue.Interface()
-
 		mapFullName := fmt.Sprintf("%s[%s]", fullName, key.String())
 
 		switch dereffedElemKind {

+ 11 - 0
core/mapping/unmarshaler_test.go

@@ -5081,6 +5081,17 @@ func TestGetValueWithChainedKeys(t *testing.T) {
 	})
 }
 
+func TestUnmarshalFromStringSliceForTypeMismatch(t *testing.T) {
+	var v struct {
+		Values map[string][]string `key:"values"`
+	}
+	assert.Error(t, UnmarshalKey(map[string]any{
+		"values": map[string]any{
+			"foo": "bar",
+		},
+	}, &v))
+}
+
 func BenchmarkDefaultValue(b *testing.B) {
 	for i := 0; i < b.N; i++ {
 		var a struct {

+ 3 - 3
core/mapping/utils.go

@@ -103,21 +103,21 @@ func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 		intValue, err := strconv.ParseInt(str, 10, 64)
 		if err != nil {
-			return 0, fmt.Errorf("the value %q cannot parsed as int", str)
+			return 0, fmt.Errorf("the value %q cannot be parsed as int", str)
 		}
 
 		return intValue, nil
 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 		uintValue, err := strconv.ParseUint(str, 10, 64)
 		if err != nil {
-			return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
+			return 0, fmt.Errorf("the value %q cannot be parsed as uint", str)
 		}
 
 		return uintValue, nil
 	case reflect.Float32, reflect.Float64:
 		floatValue, err := strconv.ParseFloat(str, 64)
 		if err != nil {
-			return 0, fmt.Errorf("the value %q cannot parsed as float", str)
+			return 0, fmt.Errorf("the value %q cannot be parsed as float", str)
 		}
 
 		return floatValue, nil

+ 10 - 3
core/stores/sqlx/sqlconn.go

@@ -291,12 +291,19 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
 }
 
 func (db *commonSqlConn) acceptable(err error) bool {
-	ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
+	if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
+		return true
+	}
+
+	if _, ok := err.(acceptableError); ok {
+		return true
+	}
+
 	if db.accept == nil {
-		return ok
+		return false
 	}
 
-	return ok || db.accept(err)
+	return db.accept(err)
 }
 
 func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,

+ 27 - 0
core/stores/sqlx/sqlconn_test.go

@@ -236,6 +236,33 @@ func TestStatement(t *testing.T) {
 	})
 }
 
+func TestBreakerWithFormatError(t *testing.T) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		conn := NewSqlConnFromDB(db, withMysqlAcceptable())
+		for i := 0; i < 1000; i++ {
+			var val string
+			if !assert.NotEqual(t, breaker.ErrServiceUnavailable,
+				conn.QueryRow(&val, "any ?, ?", "foo")) {
+				break
+			}
+		}
+	})
+}
+
+func TestBreakerWithScanError(t *testing.T) {
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		conn := NewSqlConnFromDB(db, withMysqlAcceptable())
+		for i := 0; i < 1000; i++ {
+			rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
+			mock.ExpectQuery("any").WillReturnRows(rows)
+			var val int
+			if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) {
+				break
+			}
+		}
+	})
+}
+
 func buildConn() (mock sqlmock.Sqlmock, err error) {
 	_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
 		var db *sql.DB

+ 23 - 2
core/stores/sqlx/utils.go

@@ -51,7 +51,13 @@ func escape(input string) string {
 	return b.String()
 }
 
-func format(query string, args ...any) (string, error) {
+func format(query string, args ...any) (val string, err error) {
+	defer func() {
+		if err != nil {
+			err = newAcceptableError(err)
+		}
+	}()
+
 	numArgs := len(args)
 	if numArgs == 0 {
 		return query, nil
@@ -66,7 +72,8 @@ func format(query string, args ...any) (string, error) {
 		switch ch {
 		case '?':
 			if argIndex >= numArgs {
-				return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
+				return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
+					argIndex+1, numArgs)
 			}
 
 			writeValue(&b, args[argIndex])
@@ -165,3 +172,17 @@ func writeValue(buf *strings.Builder, arg any) {
 		buf.WriteString(mapping.Repr(v))
 	}
 }
+
+type acceptableError struct {
+	err error
+}
+
+func newAcceptableError(err error) error {
+	return acceptableError{
+		err: err,
+	}
+}
+
+func (e acceptableError) Error() string {
+	return e.err.Error()
+}