Quellcode durchsuchen

chore: add more tests

kevin vor 2 Jahren
Ursprung
Commit
b449f2f39e
2 geänderte Dateien mit 409 neuen und 97 gelöschten Zeilen
  1. 121 91
      core/conf/config.go
  2. 288 6
      core/conf/config_test.go

+ 121 - 91
core/conf/config.go

@@ -13,18 +13,17 @@ import (
 	"github.com/zeromicro/go-zero/internal/encoding"
 )
 
-var (
-	loaders = map[string]func([]byte, any) error{
-		".json": LoadFromJsonBytes,
-		".toml": LoadFromTomlBytes,
-		".yaml": LoadFromYamlBytes,
-		".yml":  LoadFromYamlBytes,
-	}
-	emptyFieldInfo fieldInfo
-)
+var loaders = map[string]func([]byte, any) error{
+	".json": LoadFromJsonBytes,
+	".toml": LoadFromTomlBytes,
+	".yaml": LoadFromYamlBytes,
+	".yml":  LoadFromYamlBytes,
+}
 
+// children and mapField should not be both filled.
+// named fields and map cannot be bound to the same field name.
 type fieldInfo struct {
-	children map[string]fieldInfo
+	children map[string]*fieldInfo
 	mapField *fieldInfo
 }
 
@@ -60,13 +59,13 @@ func LoadConfig(file string, v any, opts ...Option) error {
 
 // LoadFromJsonBytes loads config into v from content json bytes.
 func LoadFromJsonBytes(content []byte, v any) error {
-	var m map[string]any
-	if err := jsonx.Unmarshal(content, &m); err != nil {
+	finfo, err := buildFieldsInfo(reflect.TypeOf(v))
+	if err != nil {
 		return err
 	}
 
-	finfo, err := buildFieldsInfo(reflect.TypeOf(v))
-	if err != nil {
+	var m map[string]any
+	if err := jsonx.Unmarshal(content, &m); err != nil {
 		return err
 	}
 
@@ -114,21 +113,15 @@ func MustLoad(path string, v any, opts ...Option) {
 	}
 }
 
-func addOrMergeFields(info fieldInfo, key string, child fieldInfo) error {
+func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
 	if prev, ok := info.children[key]; ok {
-		if len(child.children) == 0 && child.mapField == nil {
+		if child.mapField != nil {
 			return newDupKeyError(key)
 		}
 
-		// merge fields
-		for k, v := range child.children {
-			if _, ok = prev.children[k]; ok {
-				return newDupKeyError(k)
-			}
-
-			prev.children[k] = v
+		if err := mergeFields(prev, key, child.children); err != nil {
+			return err
 		}
-		prev.mapField = child.mapField
 	} else {
 		info.children[key] = child
 	}
@@ -136,7 +129,47 @@ func addOrMergeFields(info fieldInfo, key string, child fieldInfo) error {
 	return nil
 }
 
-func buildFieldsInfo(tp reflect.Type) (fieldInfo, error) {
+func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
+	switch ft.Kind() {
+	case reflect.Struct:
+		fields, err := buildFieldsInfo(ft)
+		if err != nil {
+			return err
+		}
+
+		for k, v := range fields.children {
+			if err = addOrMergeFields(info, k, v); err != nil {
+				return err
+			}
+		}
+	case reflect.Map:
+		elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
+		if err != nil {
+			return err
+		}
+
+		if _, ok := info.children[lowerCaseName]; ok {
+			return newDupKeyError(lowerCaseName)
+		}
+
+		info.children[lowerCaseName] = &fieldInfo{
+			children: make(map[string]*fieldInfo),
+			mapField: elemField,
+		}
+	default:
+		if _, ok := info.children[lowerCaseName]; ok {
+			return newDupKeyError(lowerCaseName)
+		}
+
+		info.children[lowerCaseName] = &fieldInfo{
+			children: make(map[string]*fieldInfo),
+		}
+	}
+
+	return nil
+}
+
+func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
 	tp = mapping.Deref(tp)
 
 	switch tp.Kind() {
@@ -145,13 +178,50 @@ func buildFieldsInfo(tp reflect.Type) (fieldInfo, error) {
 	case reflect.Array, reflect.Slice:
 		return buildFieldsInfo(mapping.Deref(tp.Elem()))
 	default:
-		return emptyFieldInfo, nil
+		return &fieldInfo{
+			children: make(map[string]*fieldInfo),
+		}, nil
+	}
+}
+
+func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
+	var finfo *fieldInfo
+	var err error
+
+	switch ft.Kind() {
+	case reflect.Struct:
+		finfo, err = buildFieldsInfo(ft)
+		if err != nil {
+			return err
+		}
+	case reflect.Array, reflect.Slice:
+		finfo, err = buildFieldsInfo(ft.Elem())
+		if err != nil {
+			return err
+		}
+	case reflect.Map:
+		elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
+		if err != nil {
+			return err
+		}
+
+		finfo = &fieldInfo{
+			children: make(map[string]*fieldInfo),
+			mapField: elemInfo,
+		}
+	default:
+		finfo, err = buildFieldsInfo(ft)
+		if err != nil {
+			return err
+		}
 	}
+
+	return addOrMergeFields(info, lowerCaseName, finfo)
 }
 
-func buildStructFieldsInfo(tp reflect.Type) (fieldInfo, error) {
-	info := fieldInfo{
-		children: make(map[string]fieldInfo),
+func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
+	info := &fieldInfo{
+		children: make(map[string]*fieldInfo),
 	}
 
 	for i := 0; i < tp.NumField(); i++ {
@@ -161,79 +231,39 @@ func buildStructFieldsInfo(tp reflect.Type) (fieldInfo, error) {
 		ft := mapping.Deref(field.Type)
 		// flatten anonymous fields
 		if field.Anonymous {
-			switch ft.Kind() {
-			case reflect.Struct:
-				fields, err := buildFieldsInfo(ft)
-				if err != nil {
-					return emptyFieldInfo, err
-				}
-				for k, v := range fields.children {
-					if err = addOrMergeFields(info, k, v); err != nil {
-						return emptyFieldInfo, err
-					}
-				}
-				info.mapField = fields.mapField
-			case reflect.Map:
-				elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
-				if err != nil {
-					return emptyFieldInfo, err
-				}
-				if _, ok := info.children[lowerCaseName]; ok {
-					return emptyFieldInfo, newDupKeyError(lowerCaseName)
-				}
-				info.children[lowerCaseName] = fieldInfo{
-					mapField: &elemField,
-				}
-			default:
-				if _, ok := info.children[lowerCaseName]; ok {
-					return emptyFieldInfo, newDupKeyError(lowerCaseName)
-				}
-				info.children[lowerCaseName] = fieldInfo{
-					children: make(map[string]fieldInfo),
-				}
+			if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
+				return nil, err
 			}
-			continue
+		} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
+			return nil, err
 		}
+	}
 
-		var finfo fieldInfo
-		var err error
-		switch ft.Kind() {
-		case reflect.Struct:
-			finfo, err = buildFieldsInfo(ft)
-			if err != nil {
-				return emptyFieldInfo, err
-			}
-		case reflect.Array, reflect.Slice:
-			finfo, err = buildFieldsInfo(ft.Elem())
-			if err != nil {
-				return emptyFieldInfo, err
-			}
-		case reflect.Map:
-			elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
-			if err != nil {
-				return emptyFieldInfo, err
-			}
-			finfo.mapField = &elemInfo
-		default:
-			finfo, err = buildFieldsInfo(ft)
-			if err != nil {
-				return emptyFieldInfo, err
-			}
-		}
+	return info, nil
+}
+
+func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
+	if len(children) == 0 {
+		return newDupKeyError(key)
+	}
 
-		if err := addOrMergeFields(info, lowerCaseName, finfo); err != nil {
-			return emptyFieldInfo, err
+	// merge fields
+	for k, v := range children {
+		if _, ok := prev.children[k]; ok {
+			return newDupKeyError(k)
 		}
+
+		prev.children[k] = v
 	}
 
-	return info, nil
+	return nil
 }
 
 func toLowerCase(s string) string {
 	return strings.ToLower(s)
 }
 
-func toLowerCaseInterface(v any, info fieldInfo) any {
+func toLowerCaseInterface(v any, info *fieldInfo) any {
 	switch vv := v.(type) {
 	case map[string]any:
 		return toLowerCaseKeyMap(vv, info)
@@ -248,7 +278,7 @@ func toLowerCaseInterface(v any, info fieldInfo) any {
 	}
 }
 
-func toLowerCaseKeyMap(m map[string]any, info fieldInfo) map[string]any {
+func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
 	res := make(map[string]any)
 
 	for k, v := range m {
@@ -262,7 +292,7 @@ func toLowerCaseKeyMap(m map[string]any, info fieldInfo) map[string]any {
 		if ti, ok = info.children[lk]; ok {
 			res[lk] = toLowerCaseInterface(v, ti)
 		} else if info.mapField != nil {
-			res[k] = toLowerCaseInterface(v, *info.mapField)
+			res[k] = toLowerCaseInterface(v, info.mapField)
 		} else {
 			res[k] = v
 		}

+ 288 - 6
core/conf/config_test.go

@@ -479,11 +479,7 @@ func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
 `)
 
 	var c TestConfig
-	if assert.NoError(t, LoadFromYamlBytes(input, &c)) {
-		assert.Equal(t, "localhost", c.Server.Redis.Host)
-		assert.Equal(t, 6379, c.Server.Redis.Port)
-		assert.Equal(t, "test", c.Server.Redis.Key)
-	}
+	assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
 }
 
 func TestUnmarshalJsonBytesMap(t *testing.T) {
@@ -610,7 +606,7 @@ func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
 	}
 }
 
-func Test_checkInheritOverwrite(t *testing.T) {
+func Test_FieldOverwrite(t *testing.T) {
 	t.Run("normal", func(t *testing.T) {
 		type Base struct {
 			Name string
@@ -730,6 +726,292 @@ func Test_checkInheritOverwrite(t *testing.T) {
 	})
 }
 
+func TestFieldOverwriteComplicated(t *testing.T) {
+	t.Run("double maps", func(t *testing.T) {
+		type (
+			Base1 struct {
+				Values map[string]string
+			}
+			Base2 struct {
+				Values map[string]string
+			}
+			Config struct {
+				Base1
+				Base2
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Values": {"Key": "Value"}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("merge children", func(t *testing.T) {
+		type (
+			Inner1 struct {
+				Name string
+			}
+			Inner2 struct {
+				Age int
+			}
+			Base1 struct {
+				Inner Inner1
+			}
+			Base2 struct {
+				Inner Inner2
+			}
+			Config struct {
+				Base1
+				Base2
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Inner": {"Name": "foo", "Age": 10}}`)
+		if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
+			assert.Equal(t, "foo", c.Base1.Inner.Name)
+			assert.Equal(t, 10, c.Base2.Inner.Age)
+		}
+	})
+
+	t.Run("overwritten maps", func(t *testing.T) {
+		type (
+			Inner struct {
+				Map map[string]string
+			}
+			Config struct {
+				Map map[string]string
+				Inner
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Inner": {"Map": {"Key": "Value"}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten nested maps", func(t *testing.T) {
+		type (
+			Inner struct {
+				Map map[string]string
+			}
+			Middle1 struct {
+				Map map[string]string
+				Inner
+			}
+			Middle2 struct {
+				Map map[string]string
+				Inner
+			}
+			Config struct {
+				Middle1
+				Middle2
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Middle1": {"Inner": {"Map": {"Key": "Value"}}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten outer/inner maps", func(t *testing.T) {
+		type (
+			Inner struct {
+				Map map[string]string
+			}
+			Middle struct {
+				Inner
+				Map map[string]string
+			}
+			Config struct {
+				Middle
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Middle": {"Inner": {"Map": {"Key": "Value"}}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten anonymous maps", func(t *testing.T) {
+		type (
+			Inner struct {
+				Map map[string]string
+			}
+			Middle struct {
+				Inner
+				Map map[string]string
+			}
+			Elem   map[string]Middle
+			Config struct {
+				Elem
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Elem": {"Key": {"Inner": {"Map": {"Key": "Value"}}}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten primitive and map", func(t *testing.T) {
+		type (
+			Inner struct {
+				Value string
+			}
+			Elem  map[string]Inner
+			Named struct {
+				Elem string
+			}
+			Config struct {
+				Named
+				Elem
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten map and slice", func(t *testing.T) {
+		type (
+			Inner struct {
+				Value string
+			}
+			Elem  []Inner
+			Named struct {
+				Elem string
+			}
+			Config struct {
+				Named
+				Elem
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten map and string", func(t *testing.T) {
+		type (
+			Elem  string
+			Named struct {
+				Elem string
+			}
+			Config struct {
+				Named
+				Elem
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+}
+
+func TestLoadNamedFieldOverwritten(t *testing.T) {
+	t.Run("overwritten named struct", func(t *testing.T) {
+		type (
+			Elem  string
+			Named struct {
+				Elem string
+			}
+			Base struct {
+				Named
+				Elem
+			}
+			Config struct {
+				Val Base
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten named []struct", func(t *testing.T) {
+		type (
+			Elem  string
+			Named struct {
+				Elem string
+			}
+			Base struct {
+				Named
+				Elem
+			}
+			Config struct {
+				Vals []Base
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten named map[string]struct", func(t *testing.T) {
+		type (
+			Elem  string
+			Named struct {
+				Elem string
+			}
+			Base struct {
+				Named
+				Elem
+			}
+			Config struct {
+				Vals map[string]Base
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Vals": {"Key": {"Elem": {"Key": {"Value": "Value"}}}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten named *struct", func(t *testing.T) {
+		type (
+			Elem  string
+			Named struct {
+				Elem string
+			}
+			Base struct {
+				Named
+				Elem
+			}
+			Config struct {
+				Vals *Base
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+
+	t.Run("overwritten named struct", func(t *testing.T) {
+		type (
+			Named struct {
+				Elem string
+			}
+			Base struct {
+				Named
+				Elem Named
+			}
+			Config struct {
+				Val Base
+			}
+		)
+
+		var c Config
+		input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`)
+		assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
+	})
+}
+
 func createTempFile(ext, text string) (string, error) {
 	tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
 	if err != nil {