Ver código fonte

fix: camel cased key of map item in config (#2715)

* fix: camel cased key of map item in config

* fix: mapping anonymous problem

* fix: mapping anonymous problem

* chore: refactor

* chore: add more tests

* chore: refactor
Kevin Wan 2 anos atrás
pai
commit
affbcb5698

+ 94 - 9
core/conf/config.go

@@ -5,6 +5,7 @@ import (
 	"log"
 	"os"
 	"path"
+	"reflect"
 	"strings"
 
 	"github.com/zeromicro/go-zero/core/jsonx"
@@ -21,6 +22,12 @@ var loaders = map[string]func([]byte, interface{}) error{
 	".yml":  LoadFromYamlBytes,
 }
 
+type fieldInfo struct {
+	name     string
+	kind     reflect.Kind
+	children map[string]fieldInfo
+}
+
 // Load loads config into v from file, .json, .yaml and .yml are acceptable.
 func Load(file string, v interface{}, opts ...Option) error {
 	content, err := os.ReadFile(file)
@@ -58,7 +65,10 @@ func LoadFromJsonBytes(content []byte, v interface{}) error {
 		return err
 	}
 
-	return mapping.UnmarshalJsonMap(toCamelCaseKeyMap(m), v, mapping.WithCanonicalKeyFunc(toCamelCase))
+	finfo := buildFieldsInfo(reflect.TypeOf(v))
+	camelCaseKeyMap := toCamelCaseKeyMap(m, finfo)
+
+	return mapping.UnmarshalJsonMap(camelCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toCamelCase))
 }
 
 // LoadConfigFromJsonBytes loads config into v from content json bytes.
@@ -100,6 +110,64 @@ func MustLoad(path string, v interface{}, opts ...Option) {
 	}
 }
 
+func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
+	tp = mapping.Deref(tp)
+
+	switch tp.Kind() {
+	case reflect.Struct:
+		return buildStructFieldsInfo(tp)
+	case reflect.Array, reflect.Slice:
+		return buildFieldsInfo(mapping.Deref(tp.Elem()))
+	default:
+		return nil
+	}
+}
+
+func buildStructFieldsInfo(tp reflect.Type) map[string]fieldInfo {
+	info := make(map[string]fieldInfo)
+
+	for i := 0; i < tp.NumField(); i++ {
+		field := tp.Field(i)
+		name := field.Name
+		ccName := toCamelCase(name)
+		ft := mapping.Deref(field.Type)
+
+		// flatten anonymous fields
+		if field.Anonymous {
+			if ft.Kind() == reflect.Struct {
+				fields := buildFieldsInfo(ft)
+				for k, v := range fields {
+					info[k] = v
+				}
+			} else {
+				info[ccName] = fieldInfo{
+					name: name,
+					kind: ft.Kind(),
+				}
+			}
+			continue
+		}
+
+		var fields map[string]fieldInfo
+		switch ft.Kind() {
+		case reflect.Struct:
+			fields = buildFieldsInfo(ft)
+		case reflect.Array, reflect.Slice:
+			fields = buildFieldsInfo(ft.Elem())
+		case reflect.Map:
+			fields = buildFieldsInfo(ft.Elem())
+		}
+
+		info[ccName] = fieldInfo{
+			name:     name,
+			kind:     ft.Kind(),
+			children: fields,
+		}
+	}
+
+	return info
+}
+
 func toCamelCase(s string) string {
 	var buf strings.Builder
 	buf.Grow(len(s))
@@ -123,14 +191,19 @@ func toCamelCase(s string) string {
 		if isCap || isLow {
 			buf.WriteRune(v)
 			capNext = false
-		} else if v == ' ' || v == '\t' {
+			continue
+		}
+
+		switch v {
+		// '.' is used for chained keys, e.g. "grand.parent.child"
+		case ' ', '.', '\t':
 			buf.WriteRune(v)
 			capNext = false
 			boundary = true
-		} else if v == '_' {
+		case '_':
 			capNext = true
 			boundary = true
-		} else {
+		default:
 			buf.WriteRune(v)
 			capNext = true
 		}
@@ -139,14 +212,14 @@ func toCamelCase(s string) string {
 	return buf.String()
 }
 
-func toCamelCaseInterface(v interface{}) interface{} {
+func toCamelCaseInterface(v interface{}, info map[string]fieldInfo) interface{} {
 	switch vv := v.(type) {
 	case map[string]interface{}:
-		return toCamelCaseKeyMap(vv)
+		return toCamelCaseKeyMap(vv, info)
 	case []interface{}:
 		var arr []interface{}
 		for _, vvv := range vv {
-			arr = append(arr, toCamelCaseInterface(vvv))
+			arr = append(arr, toCamelCaseInterface(vvv, info))
 		}
 		return arr
 	default:
@@ -154,10 +227,22 @@ func toCamelCaseInterface(v interface{}) interface{} {
 	}
 }
 
-func toCamelCaseKeyMap(m map[string]interface{}) map[string]interface{} {
+func toCamelCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[string]interface{} {
 	res := make(map[string]interface{})
+
 	for k, v := range m {
-		res[toCamelCase(k)] = toCamelCaseInterface(v)
+		ti, ok := info[k]
+		if ok {
+			res[k] = toCamelCaseInterface(v, ti.children)
+			continue
+		}
+
+		cck := toCamelCase(k)
+		if ti, ok = info[cck]; ok {
+			res[toCamelCase(k)] = toCamelCaseInterface(v, ti.children)
+		} else {
+			res[k] = v
+		}
 	}
 
 	return res

+ 82 - 0
core/conf/config_test.go

@@ -283,6 +283,10 @@ func TestToCamelCase(t *testing.T) {
 			input:  "Hello World Foo_Bar",
 			expect: "hello world fooBar",
 		},
+		{
+			input:  "Hello.World Foo_Bar",
+			expect: "hello.world fooBar",
+		},
 		{
 			input:  "你好 World Foo_Bar",
 			expect: "你好 world fooBar",
@@ -328,6 +332,84 @@ func TestLoadFromYamlBytes(t *testing.T) {
 	assert.Equal(t, "foo", val.Layer1.Layer2.Layer3)
 }
 
+func TestLoadFromYamlBytesLayers(t *testing.T) {
+	input := []byte(`layer1:
+  layer2:
+    layer3: foo`)
+	var val struct {
+		Value string `json:"Layer1.Layer2.Layer3"`
+	}
+
+	assert.NoError(t, LoadFromYamlBytes(input, &val))
+	assert.Equal(t, "foo", val.Value)
+}
+
+func TestUnmarshalJsonBytesMap(t *testing.T) {
+	input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
+
+	var val struct {
+		Foo map[string]string
+	}
+
+	assert.NoError(t, LoadFromJsonBytes(input, &val))
+	assert.Equal(t, "bff.bff", val.Foo["/mtproto.RPCTos"])
+	assert.Equal(t, "baz", val.Foo["bar"])
+}
+
+func TestUnmarshalJsonBytesMapWithSliceElements(t *testing.T) {
+	input := []byte(`{"foo":{"/mtproto.RPCTos": ["bff.bff", "any"],"bar":["baz", "qux"]}}`)
+
+	var val struct {
+		Foo map[string][]string
+	}
+
+	assert.NoError(t, LoadFromJsonBytes(input, &val))
+	assert.EqualValues(t, []string{"bff.bff", "any"}, val.Foo["/mtproto.RPCTos"])
+	assert.EqualValues(t, []string{"baz", "qux"}, val.Foo["bar"])
+}
+
+func TestUnmarshalJsonBytesMapWithSliceOfStructs(t *testing.T) {
+	input := []byte(`{"foo":{
+	"/mtproto.RPCTos": [{"bar": "any"}],
+	"bar":[{"bar": "qux"}, {"bar": "ever"}]}}`)
+
+	var val struct {
+		Foo map[string][]struct {
+			Bar string
+		}
+	}
+
+	assert.NoError(t, LoadFromJsonBytes(input, &val))
+	assert.Equal(t, 1, len(val.Foo["/mtproto.RPCTos"]))
+	assert.Equal(t, "any", val.Foo["/mtproto.RPCTos"][0].Bar)
+	assert.Equal(t, 2, len(val.Foo["bar"]))
+	assert.Equal(t, "qux", val.Foo["bar"][0].Bar)
+	assert.Equal(t, "ever", val.Foo["bar"][1].Bar)
+}
+
+func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
+	type (
+		Int int
+
+		InnerConf struct {
+			Name string
+		}
+
+		Conf struct {
+			Int
+			InnerConf
+		}
+	)
+
+	var (
+		input = []byte(`{"Name": "hello", "int": 3}`)
+		c     Conf
+	)
+	assert.NoError(t, LoadFromJsonBytes(input, &c))
+	assert.Equal(t, "hello", c.Name)
+	assert.Equal(t, Int(3), c.Int)
+}
+
 func createTempFile(ext, text string) (string, error) {
 	tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
 	if err != nil {

+ 39 - 22
core/mapping/unmarshaler.go

@@ -376,19 +376,51 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
 		return err
 	}
 
-	if _, hasValue := getValue(m, key); hasValue {
-		return fmt.Errorf("fields of %s can't be wrapped inside, because it's anonymous", key)
-	}
-
 	if options.optional() {
-		return u.processAnonymousFieldOptional(field.Type, value, key, m, fullName)
+		return u.processAnonymousFieldOptional(field, value, key, m, fullName)
 	}
 
-	return u.processAnonymousFieldRequired(field.Type, value, m, fullName)
+	return u.processAnonymousFieldRequired(field, value, m, fullName)
 }
 
-func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, value reflect.Value,
+func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value,
 	key string, m valuerWithParent, fullName string) error {
+	derefedFieldType := Deref(field.Type)
+
+	switch derefedFieldType.Kind() {
+	case reflect.Struct:
+		return u.processAnonymousStructFieldOptional(field.Type, value, key, m, fullName)
+	default:
+		return u.processNamedField(field, value, m, fullName)
+	}
+}
+
+func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value,
+	m valuerWithParent, fullName string) error {
+	fieldType := field.Type
+	maybeNewValue(fieldType, value)
+	derefedFieldType := Deref(fieldType)
+	indirectValue := reflect.Indirect(value)
+
+	switch derefedFieldType.Kind() {
+	case reflect.Struct:
+		for i := 0; i < derefedFieldType.NumField(); i++ {
+			if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i),
+				m, fullName); err != nil {
+				return err
+			}
+		}
+	default:
+		if err := u.processNamedField(field, indirectValue, m, fullName); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type,
+	value reflect.Value, key string, m valuerWithParent, fullName string) error {
 	var filled bool
 	var required int
 	var requiredFilled int
@@ -428,21 +460,6 @@ func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, valu
 	return nil
 }
 
-func (u *Unmarshaler) processAnonymousFieldRequired(fieldType reflect.Type, value reflect.Value,
-	m valuerWithParent, fullName string) error {
-	maybeNewValue(fieldType, value)
-	derefedFieldType := Deref(fieldType)
-	indirectValue := reflect.Indirect(value)
-
-	for i := 0; i < derefedFieldType.NumField(); i++ {
-		if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i), m, fullName); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
 func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value,
 	m valuerWithParent, fullName string) error {
 	if usingDifferentKeys(u.key, field) {

+ 143 - 0
core/mapping/unmarshaler_test.go

@@ -212,6 +212,24 @@ func TestUnmarshalIntPtr(t *testing.T) {
 	assert.Equal(t, 1, *in.Int)
 }
 
+func TestUnmarshalIntSliceOfPtr(t *testing.T) {
+	type inner struct {
+		Ints []*int `key:"ints"`
+	}
+	m := map[string]interface{}{
+		"ints": []int{1, 2, 3},
+	}
+
+	var in inner
+	assert.NoError(t, UnmarshalKey(m, &in))
+	assert.NotEmpty(t, in.Ints)
+	var ints []int
+	for _, i := range in.Ints {
+		ints = append(ints, *i)
+	}
+	assert.EqualValues(t, []int{1, 2, 3}, ints)
+}
+
 func TestUnmarshalIntWithDefault(t *testing.T) {
 	type inner struct {
 		Int int `key:"int,default=5"`
@@ -3665,6 +3683,7 @@ func TestUnmarshalJsonBytesSliceOfMaps(t *testing.T) {
 			Name         string `json:"name"`
 			ActualAmount int    `json:"actual_amount"`
 		}
+
 		OrderApplyRefundReq struct {
 			OrderId       string            `json:"order_id"`
 			RefundReason  RefundReasonData  `json:"refund_reason,optional"`
@@ -3676,6 +3695,130 @@ func TestUnmarshalJsonBytesSliceOfMaps(t *testing.T) {
 	assert.NoError(t, UnmarshalJsonBytes(input, &req))
 }
 
+func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
+	type (
+		Int int
+
+		InnerConf struct {
+			Name string
+		}
+
+		Conf struct {
+			Int
+			InnerConf
+		}
+	)
+
+	var (
+		input = []byte(`{"Name": "hello", "Int": 3}`)
+		c     Conf
+	)
+	assert.NoError(t, UnmarshalJsonBytes(input, &c))
+	assert.Equal(t, "hello", c.Name)
+	assert.Equal(t, Int(3), c.Int)
+}
+
+func TestUnmarshalJsonBytesWithAnonymousFieldOptional(t *testing.T) {
+	type (
+		Int int
+
+		InnerConf struct {
+			Name string
+		}
+
+		Conf struct {
+			Int `json:",optional"`
+			InnerConf
+		}
+	)
+
+	var (
+		input = []byte(`{"Name": "hello", "Int": 3}`)
+		c     Conf
+	)
+	assert.NoError(t, UnmarshalJsonBytes(input, &c))
+	assert.Equal(t, "hello", c.Name)
+	assert.Equal(t, Int(3), c.Int)
+}
+
+func TestUnmarshalJsonBytesWithAnonymousFieldBadTag(t *testing.T) {
+	type (
+		Int int
+
+		InnerConf struct {
+			Name string
+		}
+
+		Conf struct {
+			Int `json:",optional=123"`
+			InnerConf
+		}
+	)
+
+	var (
+		input = []byte(`{"Name": "hello", "Int": 3}`)
+		c     Conf
+	)
+	assert.Error(t, UnmarshalJsonBytes(input, &c))
+}
+
+func TestUnmarshalJsonBytesWithAnonymousFieldBadValue(t *testing.T) {
+	type (
+		Int int
+
+		InnerConf struct {
+			Name string
+		}
+
+		Conf struct {
+			Int
+			InnerConf
+		}
+	)
+
+	var (
+		input = []byte(`{"Name": "hello", "Int": "3"}`)
+		c     Conf
+	)
+	assert.Error(t, UnmarshalJsonBytes(input, &c))
+}
+
+func TestUnmarshalJsonBytesWithAnonymousFieldBadTagInStruct(t *testing.T) {
+	type (
+		InnerConf struct {
+			Name string `json:",optional=123"`
+		}
+
+		Conf struct {
+			InnerConf `json:",optional"`
+		}
+	)
+
+	var (
+		input = []byte(`{"Name": "hello"}`)
+		c     Conf
+	)
+	assert.Error(t, UnmarshalJsonBytes(input, &c))
+}
+
+func TestUnmarshalJsonBytesWithAnonymousFieldNotInOptions(t *testing.T) {
+	type (
+		InnerConf struct {
+			Name string `json:",options=[a,b]"`
+		}
+
+		Conf struct {
+			InnerConf `json:",optional"`
+		}
+	)
+
+	var (
+		input = []byte(`{"Name": "hello"}`)
+		c     Conf
+	)
+	assert.Error(t, UnmarshalJsonBytes(input, &c))
+}
+
 func BenchmarkDefaultValue(b *testing.B) {
 	for i := 0; i < b.N; i++ {
 		var a struct {

+ 24 - 0
core/mapping/valuer_test.go

@@ -31,3 +31,27 @@ func TestMapValuerWithInherit_Value(t *testing.T) {
 	assert.Equal(t, "localhost", m["host"])
 	assert.Equal(t, 8080, m["port"])
 }
+
+func TestRecursiveValuer_Value(t *testing.T) {
+	input := map[string]interface{}{
+		"component": map[string]interface{}{
+			"name": "test",
+			"foo": map[string]interface{}{
+				"bar": "baz",
+			},
+		},
+		"foo": "value",
+	}
+	valuer := recursiveValuer{
+		current: mapValuer(input["component"].(map[string]interface{})),
+		parent: simpleValuer{
+			current: mapValuer(input),
+		},
+	}
+
+	val, ok := valuer.Value("foo")
+	assert.True(t, ok)
+	assert.EqualValues(t, map[string]interface{}{
+		"bar": "baz",
+	}, val)
+}