Pārlūkot izejas kodu

feat: check key overwritten

kevin 2 gadi atpakaļ
vecāks
revīzija
c57b0b8f90
2 mainītis faili ar 203 papildinājumiem un 29 dzēšanām
  1. 79 19
      core/conf/config.go
  2. 124 10
      core/conf/config_test.go

+ 79 - 19
core/conf/config.go

@@ -13,12 +13,15 @@ import (
 	"github.com/zeromicro/go-zero/internal/encoding"
 )
 
-var loaders = map[string]func([]byte, any) error{
-	".json": LoadFromJsonBytes,
-	".toml": LoadFromTomlBytes,
-	".yaml": LoadFromYamlBytes,
-	".yml":  LoadFromYamlBytes,
-}
+var (
+	loaders = map[string]func([]byte, any) error{
+		".json": LoadFromJsonBytes,
+		".toml": LoadFromTomlBytes,
+		".yaml": LoadFromYamlBytes,
+		".yml":  LoadFromYamlBytes,
+	}
+	emptyFieldInfo fieldInfo
+)
 
 type fieldInfo struct {
 	children map[string]fieldInfo
@@ -62,7 +65,11 @@ func LoadFromJsonBytes(content []byte, v any) error {
 		return err
 	}
 
-	finfo := buildFieldsInfo(reflect.TypeOf(v))
+	finfo, err := buildFieldsInfo(reflect.TypeOf(v))
+	if err != nil {
+		return err
+	}
+
 	lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo)
 
 	return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
@@ -107,19 +114,29 @@ func MustLoad(path string, v any, opts ...Option) {
 	}
 }
 
-func addOrMergeFields(info fieldInfo, key string, child fieldInfo) {
+func addOrMergeFields(info fieldInfo, key string, child fieldInfo) error {
 	if prev, ok := info.children[key]; ok {
+		if len(child.children) == 0 && child.mapField == nil {
+			return newDupKeyError(key)
+		}
+
 		// merge fields
 		for k, v := range child.children {
+			if _, ok = prev.children[k]; ok {
+				return newDupKeyError(k)
+			}
+
 			prev.children[k] = v
 		}
 		prev.mapField = child.mapField
 	} else {
 		info.children[key] = child
 	}
+
+	return nil
 }
 
-func buildFieldsInfo(tp reflect.Type) fieldInfo {
+func buildFieldsInfo(tp reflect.Type) (fieldInfo, error) {
 	tp = mapping.Deref(tp)
 
 	switch tp.Kind() {
@@ -128,11 +145,11 @@ func buildFieldsInfo(tp reflect.Type) fieldInfo {
 	case reflect.Array, reflect.Slice:
 		return buildFieldsInfo(mapping.Deref(tp.Elem()))
 	default:
-		return fieldInfo{}
+		return emptyFieldInfo, nil
 	}
 }
 
-func buildStructFieldsInfo(tp reflect.Type) fieldInfo {
+func buildStructFieldsInfo(tp reflect.Type) (fieldInfo, error) {
 	info := fieldInfo{
 		children: make(map[string]fieldInfo),
 	}
@@ -146,17 +163,31 @@ func buildStructFieldsInfo(tp reflect.Type) fieldInfo {
 		if field.Anonymous {
 			switch ft.Kind() {
 			case reflect.Struct:
-				fields := buildFieldsInfo(ft)
+				fields, err := buildFieldsInfo(ft)
+				if err != nil {
+					return emptyFieldInfo, err
+				}
 				for k, v := range fields.children {
-					addOrMergeFields(info, k, v)
+					if err = addOrMergeFields(info, k, v); err != nil {
+						return emptyFieldInfo, err
+					}
 				}
 				info.mapField = fields.mapField
 			case reflect.Map:
-				elemField := buildFieldsInfo(mapping.Deref(ft.Elem()))
+				elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
+				if err != nil {
+					return emptyFieldInfo, err
+				}
+				if _, ok := info.children[lowerCaseName]; ok {
+					return emptyFieldInfo, newDupKeyError(lowerCaseName)
+				}
 				info.children[lowerCaseName] = fieldInfo{
 					mapField: &elemField,
 				}
 			default:
+				if _, ok := info.children[lowerCaseName]; ok {
+					return emptyFieldInfo, newDupKeyError(lowerCaseName)
+				}
 				info.children[lowerCaseName] = fieldInfo{
 					children: make(map[string]fieldInfo),
 				}
@@ -165,20 +196,37 @@ func buildStructFieldsInfo(tp reflect.Type) fieldInfo {
 		}
 
 		var finfo fieldInfo
+		var err error
 		switch ft.Kind() {
 		case reflect.Struct:
-			finfo = buildFieldsInfo(ft)
+			finfo, err = buildFieldsInfo(ft)
+			if err != nil {
+				return emptyFieldInfo, err
+			}
 		case reflect.Array, reflect.Slice:
-			finfo = buildFieldsInfo(ft.Elem())
+			finfo, err = buildFieldsInfo(ft.Elem())
+			if err != nil {
+				return emptyFieldInfo, err
+			}
 		case reflect.Map:
-			elemInfo := buildFieldsInfo(mapping.Deref(ft.Elem()))
+			elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
+			if err != nil {
+				return emptyFieldInfo, err
+			}
 			finfo.mapField = &elemInfo
+		default:
+			finfo, err = buildFieldsInfo(ft)
+			if err != nil {
+				return emptyFieldInfo, err
+			}
 		}
 
-		addOrMergeFields(info, lowerCaseName, finfo)
+		if err := addOrMergeFields(info, lowerCaseName, finfo); err != nil {
+			return emptyFieldInfo, err
+		}
 	}
 
-	return info
+	return info, nil
 }
 
 func toLowerCase(s string) string {
@@ -222,3 +270,15 @@ func toLowerCaseKeyMap(m map[string]any, info fieldInfo) map[string]any {
 
 	return res
 }
+
+type dupKeyError struct {
+	key string
+}
+
+func newDupKeyError(key string) dupKeyError {
+	return dupKeyError{key: key}
+}
+
+func (e dupKeyError) Error() string {
+	return fmt.Sprintf("duplicated key %s", e.key)
+}

+ 124 - 10
core/conf/config_test.go

@@ -9,6 +9,8 @@ import (
 	"github.com/zeromicro/go-zero/core/hash"
 )
 
+var dupErr dupKeyError
+
 func TestLoadConfig_notExists(t *testing.T) {
 	assert.NotNil(t, Load("not_a_file", nil))
 }
@@ -413,11 +415,7 @@ func TestLoadFromYamlItemOverlay(t *testing.T) {
 `)
 
 	var c TestConfig
-	if assert.NoError(t, LoadFromYamlBytes(input, &c)) {
-		assert.Equal(t, "localhost", c.Redis.Host)
-		assert.Equal(t, 6379, c.Redis.Port)
-		assert.Equal(t, "test", c.Server.Redis.Key)
-	}
+	assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
 }
 
 func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
@@ -449,11 +447,7 @@ func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
 `)
 
 	var c TestConfig
-	if assert.NoError(t, LoadFromYamlBytes(input, &c)) {
-		assert.Equal(t, "localhost", c.Redis.Host)
-		assert.Equal(t, 6379, c.Redis.Port)
-		assert.Equal(t, "test", c.Redis.Key)
-	}
+	assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
 }
 
 func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
@@ -616,6 +610,126 @@ func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
 	}
 }
 
+func Test_checkInheritOverwrite(t *testing.T) {
+	t.Run("normal", func(t *testing.T) {
+		type Base struct {
+			Name string
+		}
+
+		type St1 struct {
+			Base
+			Name2 string
+		}
+
+		type St2 struct {
+			Base
+			Name2 string
+		}
+
+		type St3 struct {
+			*Base
+			Name2 string
+		}
+
+		type St4 struct {
+			*Base
+			Name2 *string
+		}
+
+		validate := func(val any) {
+			input := []byte(`{"Name": "hello", "Name2": "world"}`)
+			assert.NoError(t, LoadFromJsonBytes(input, val))
+		}
+
+		validate(&St1{})
+		validate(&St2{})
+		validate(&St3{})
+		validate(&St4{})
+	})
+
+	t.Run("Inherit Override", func(t *testing.T) {
+		type Base struct {
+			Name string
+		}
+
+		type St1 struct {
+			Base
+			Name string
+		}
+
+		type St2 struct {
+			Base
+			Name int
+		}
+
+		type St3 struct {
+			*Base
+			Name int
+		}
+
+		type St4 struct {
+			*Base
+			Name *string
+		}
+
+		validate := func(val any) {
+			input := []byte(`{"Name": "hello"}`)
+			err := LoadFromJsonBytes(input, val)
+			assert.ErrorAs(t, err, &dupErr)
+			assert.Equal(t, newDupKeyError("name").Error(), err.Error())
+		}
+
+		validate(&St1{})
+		validate(&St2{})
+		validate(&St3{})
+		validate(&St4{})
+	})
+
+	t.Run("Inherit more", func(t *testing.T) {
+		type Base1 struct {
+			Name string
+		}
+
+		type St0 struct {
+			Base1
+			Name string
+		}
+
+		type St1 struct {
+			St0
+			Name string
+		}
+
+		type St2 struct {
+			St0
+			Name int
+		}
+
+		type St3 struct {
+			*St0
+			Name int
+		}
+
+		type St4 struct {
+			*St0
+			Name *int
+		}
+
+		validate := func(val any) {
+			input := []byte(`{"Name": "hello"}`)
+			err := LoadFromJsonBytes(input, val)
+			assert.ErrorAs(t, err, &dupErr)
+			assert.Equal(t, newDupKeyError("name").Error(), err.Error())
+		}
+
+		validate(&St0{})
+		validate(&St1{})
+		validate(&St2{})
+		validate(&St3{})
+		validate(&St4{})
+	})
+}
+
 func createTempFile(ext, text string) (string, error) {
 	tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
 	if err != nil {