فهرست منبع

feat: support array in default and options tags (#1386)

* feat: support array in default and options tags

* feat: ignore spaces in tags

* test: add more tests
Kevin Wan 3 سال پیش
والد
کامیت
23deaf50e6
4فایلهای تغییر یافته به همراه315 افزوده شده و 36 حذف شده
  1. 46 27
      core/mapping/unmarshaler.go
  2. 105 0
      core/mapping/unmarshaler_test.go
  3. 88 9
      core/mapping/utils.go
  4. 76 0
      core/mapping/utils_test.go

+ 46 - 27
core/mapping/unmarshaler.go

@@ -7,7 +7,6 @@ import (
 	"reflect"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/tal-tech/go-zero/core/jsonx"
@@ -25,15 +24,17 @@ var (
 	errValueNotSettable = errors.New("value is not settable")
 	errValueNotStruct   = errors.New("value type is not struct")
 	keyUnmarshaler      = NewUnmarshaler(defaultKeyName)
-	cacheKeys           atomic.Value
-	cacheKeysLock       sync.Mutex
 	durationType        = reflect.TypeOf(time.Duration(0))
+	cacheKeys           map[string][]string
+	cacheKeysLock       sync.Mutex
+	defaultCache        map[string]interface{}
+	defaultCacheLock    sync.Mutex
 	emptyMap            = map[string]interface{}{}
 	emptyValue          = reflect.ValueOf(lang.Placeholder)
 )
 
 type (
-	// A Unmarshaler is used to unmarshal with given tag key.
+	// Unmarshaler is used to unmarshal with given tag key.
 	Unmarshaler struct {
 		key  string
 		opts unmarshalOptions
@@ -46,12 +47,11 @@ type (
 		fromString   bool
 		canonicalKey func(key string) string
 	}
-
-	keyCache map[string][]string
 )
 
 func init() {
-	cacheKeys.Store(make(keyCache))
+	cacheKeys = make(map[string][]string)
+	defaultCache = make(map[string]interface{})
 }
 
 // NewUnmarshaler returns a Unmarshaler.
@@ -388,7 +388,13 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(field reflect.StructField, v
 		if derefedType == durationType {
 			return fillDurationValue(fieldKind, value, defaultValue)
 		}
-		return setValue(fieldKind, value, defaultValue)
+
+		switch fieldKind {
+		case reflect.Array, reflect.Slice:
+			return u.fillSliceWithDefault(derefedType, value, defaultValue)
+		default:
+			return setValue(fieldKind, value, defaultValue)
+		}
 	}
 
 	switch fieldKind {
@@ -502,7 +508,8 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
 	return nil
 }
 
-func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, baseKind reflect.Kind, value interface{}) error {
+func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
+	baseKind reflect.Kind, value interface{}) error {
 	ithVal := slice.Index(index)
 	switch v := value.(type) {
 	case json.Number:
@@ -531,6 +538,28 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, baseKind re
 	}
 }
 
+func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value,
+	defaultValue string) error {
+	baseFieldType := Deref(derefedType.Elem())
+	baseFieldKind := baseFieldType.Kind()
+	defaultCacheLock.Lock()
+	slice, ok := defaultCache[defaultValue]
+	defaultCacheLock.Unlock()
+	if !ok {
+		if baseFieldKind == reflect.String {
+			slice = parseGroupedSegments(defaultValue)
+		} else if err := jsonx.UnmarshalFromString(defaultValue, &slice); err != nil {
+			return err
+		}
+
+		defaultCacheLock.Lock()
+		defaultCache[defaultValue] = slice
+		defaultCacheLock.Unlock()
+	}
+
+	return u.fillSlice(derefedType, value, slice)
+}
+
 func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue interface{}) (reflect.Value, error) {
 	mapType := reflect.MapOf(keyType, elemType)
 	valueType := reflect.TypeOf(mapValue)
@@ -724,20 +753,6 @@ func getValueWithChainedKeys(m Valuer, keys []string) (interface{}, bool) {
 	return nil, false
 }
 
-func insertKeys(key string, cache []string) {
-	cacheKeysLock.Lock()
-	defer cacheKeysLock.Unlock()
-
-	keys := cacheKeys.Load().(keyCache)
-	// copy the contents into the new map, to guarantee the old map is immutable
-	newKeys := make(keyCache)
-	for k, v := range keys {
-		newKeys[k] = v
-	}
-	newKeys[key] = cache
-	cacheKeys.Store(newKeys)
-}
-
 func join(elem ...string) string {
 	var builder strings.Builder
 
@@ -768,15 +783,19 @@ func newTypeMismatchError(name string) error {
 }
 
 func readKeys(key string) []string {
-	cache := cacheKeys.Load().(keyCache)
-	if keys, ok := cache[key]; ok {
+	cacheKeysLock.Lock()
+	keys, ok := cacheKeys[key]
+	cacheKeysLock.Unlock()
+	if ok {
 		return keys
 	}
 
-	keys := strings.FieldsFunc(key, func(c rune) bool {
+	keys = strings.FieldsFunc(key, func(c rune) bool {
 		return c == delimiter
 	})
-	insertKeys(key, keys)
+	cacheKeysLock.Lock()
+	cacheKeys[key] = keys
+	cacheKeysLock.Unlock()
 
 	return keys
 }

+ 105 - 0
core/mapping/unmarshaler_test.go

@@ -198,6 +198,66 @@ func TestUnmarshalIntWithDefault(t *testing.T) {
 	assert.Equal(t, 1, in.Int)
 }
 
+func TestUnmarshalBoolSliceWithDefault(t *testing.T) {
+	type inner struct {
+		Bools []bool `key:"bools,default=[true,false]"`
+	}
+
+	var in inner
+	assert.Nil(t, UnmarshalKey(nil, &in))
+	assert.ElementsMatch(t, []bool{true, false}, in.Bools)
+}
+
+func TestUnmarshalIntSliceWithDefault(t *testing.T) {
+	type inner struct {
+		Ints []int `key:"ints,default=[1,2,3]"`
+	}
+
+	var in inner
+	assert.Nil(t, UnmarshalKey(nil, &in))
+	assert.ElementsMatch(t, []int{1, 2, 3}, in.Ints)
+}
+
+func TestUnmarshalIntSliceWithDefaultHasSpaces(t *testing.T) {
+	type inner struct {
+		Ints []int `key:"ints,default=[1, 2, 3]"`
+	}
+
+	var in inner
+	assert.Nil(t, UnmarshalKey(nil, &in))
+	assert.ElementsMatch(t, []int{1, 2, 3}, in.Ints)
+}
+
+func TestUnmarshalFloatSliceWithDefault(t *testing.T) {
+	type inner struct {
+		Floats []float32 `key:"floats,default=[1.1,2.2,3.3]"`
+	}
+
+	var in inner
+	assert.Nil(t, UnmarshalKey(nil, &in))
+	assert.ElementsMatch(t, []float32{1.1, 2.2, 3.3}, in.Floats)
+}
+
+func TestUnmarshalStringSliceWithDefault(t *testing.T) {
+	type inner struct {
+		Strs []string `key:"strs,default=[foo,bar,woo]"`
+	}
+
+	var in inner
+	assert.Nil(t, UnmarshalKey(nil, &in))
+	assert.ElementsMatch(t, []string{"foo", "bar", "woo"}, in.Strs)
+}
+
+func TestUnmarshalStringSliceWithDefaultHasSpaces(t *testing.T) {
+	type inner struct {
+		Strs []string `key:"strs,default=[foo, bar, woo]"`
+	}
+
+	var in inner
+	assert.Nil(t, UnmarshalKey(nil, &in))
+	assert.ElementsMatch(t, []string{"foo", "bar", "woo"}, in.Strs)
+}
+
 func TestUnmarshalUint(t *testing.T) {
 	type inner struct {
 		Uint          uint   `key:"uint"`
@@ -861,10 +921,12 @@ func TestUnmarshalSliceOfStruct(t *testing.T) {
 func TestUnmarshalWithStringOptionsCorrect(t *testing.T) {
 	type inner struct {
 		Value   string `key:"value,options=first|second"`
+		Foo     string `key:"foo,options=[bar,baz]"`
 		Correct string `key:"correct,options=1|2"`
 	}
 	m := map[string]interface{}{
 		"value":   "first",
+		"foo":     "bar",
 		"correct": "2",
 	}
 
@@ -872,6 +934,7 @@ func TestUnmarshalWithStringOptionsCorrect(t *testing.T) {
 	ast := assert.New(t)
 	ast.Nil(UnmarshalKey(m, &in))
 	ast.Equal("first", in.Value)
+	ast.Equal("bar", in.Foo)
 	ast.Equal("2", in.Correct)
 }
 
@@ -943,6 +1006,22 @@ func TestUnmarshalStringOptionsWithStringOptionsIncorrect(t *testing.T) {
 	ast.NotNil(unmarshaler.Unmarshal(m, &in))
 }
 
+func TestUnmarshalStringOptionsWithStringOptionsIncorrectGrouped(t *testing.T) {
+	type inner struct {
+		Value   string `key:"value,options=[first,second]"`
+		Correct string `key:"correct,options=1|2"`
+	}
+	m := map[string]interface{}{
+		"value":   "third",
+		"correct": "2",
+	}
+
+	var in inner
+	unmarshaler := NewUnmarshaler(defaultKeyName, WithStringValues())
+	ast := assert.New(t)
+	ast.NotNil(unmarshaler.Unmarshal(m, &in))
+}
+
 func TestUnmarshalWithStringOptionsIncorrect(t *testing.T) {
 	type inner struct {
 		Value     string `key:"value,options=first|second"`
@@ -2518,3 +2597,29 @@ func TestUnmarshalJsonReaderPtrArray(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(t, 3, len(res.B))
 }
+
+func TestUnmarshalJsonWithoutKey(t *testing.T) {
+	payload := `{"A": "1", "B": "2"}`
+	var res struct {
+		A string `json:""`
+		B string `json:","`
+	}
+	reader := strings.NewReader(payload)
+	err := UnmarshalJsonReader(reader, &res)
+	assert.Nil(t, err)
+	assert.Equal(t, "1", res.A)
+	assert.Equal(t, "2", res.B)
+}
+
+func BenchmarkDefaultValue(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		var a struct {
+			Ints []int    `json:"ints,default=[1,2,3]"`
+			Strs []string `json:"strs,default=[foo,bar,baz]"`
+		}
+		_ = UnmarshalJsonMap(nil, &a)
+		if len(a.Strs) != 3 || len(a.Ints) != 3 {
+			b.Fatal("failed")
+		}
+	}
+}

+ 88 - 9
core/mapping/utils.go

@@ -14,13 +14,19 @@ import (
 )
 
 const (
-	defaultOption   = "default"
-	stringOption    = "string"
-	optionalOption  = "optional"
-	optionsOption   = "options"
-	rangeOption     = "range"
-	optionSeparator = "|"
-	equalToken      = "="
+	defaultOption      = "default"
+	stringOption       = "string"
+	optionalOption     = "optional"
+	optionsOption      = "options"
+	rangeOption        = "range"
+	optionSeparator    = "|"
+	equalToken         = "="
+	escapeChar         = '\\'
+	leftBracket        = '('
+	rightBracket       = ')'
+	leftSquareBracket  = '['
+	rightSquareBracket = ']'
+	segmentSeparator   = ','
 )
 
 var (
@@ -118,7 +124,7 @@ func convertType(kind reflect.Kind, str string) (interface{}, error) {
 }
 
 func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
-	segments := strings.Split(value, ",")
+	segments := parseSegments(value)
 	key := strings.TrimSpace(segments[0])
 	options := segments[1:]
 
@@ -198,6 +204,16 @@ func maybeNewValue(field reflect.StructField, value reflect.Value) {
 	}
 }
 
+func parseGroupedSegments(val string) []string {
+	val = strings.TrimLeftFunc(val, func(r rune) bool {
+		return r == leftBracket || r == leftSquareBracket
+	})
+	val = strings.TrimRightFunc(val, func(r rune) bool {
+		return r == rightBracket || r == rightSquareBracket
+	})
+	return parseSegments(val)
+}
+
 // don't modify returned fieldOptions, it's cached and shared among different calls.
 func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fieldOptions, error) {
 	value := field.Tag.Get(tagName)
@@ -309,7 +325,7 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
 			return fmt.Errorf("field %s has wrong options", fieldName)
 		}
 
-		fieldOpts.Options = strings.Split(segs[1], optionSeparator)
+		fieldOpts.Options = parseOptions(segs[1])
 	case strings.HasPrefix(option, defaultOption):
 		segs := strings.Split(option, equalToken)
 		if len(segs) != 2 {
@@ -334,6 +350,69 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
 	return nil
 }
 
+// parseOptions parses the given options in tag.
+// for example: `json:"name,options=foo|bar"` or `json:"name,options=[foo,bar]"`
+func parseOptions(val string) []string {
+	if len(val) == 0 {
+		return nil
+	}
+
+	if val[0] == leftSquareBracket {
+		return parseGroupedSegments(val)
+	}
+
+	return strings.Split(val, optionSeparator)
+}
+
+func parseSegments(val string) []string {
+	var segments []string
+	var escaped, grouped bool
+	var buf strings.Builder
+
+	for _, ch := range val {
+		if escaped {
+			buf.WriteRune(ch)
+			escaped = false
+			continue
+		}
+
+		switch ch {
+		case segmentSeparator:
+			if grouped {
+				buf.WriteRune(ch)
+			} else {
+				// need to trim spaces, but we cannot ignore empty string,
+				// because the first segment stands for the key might be empty.
+				// if ignored, the later tag will be used as the key.
+				segments = append(segments, strings.TrimSpace(buf.String()))
+				buf.Reset()
+			}
+		case escapeChar:
+			if grouped {
+				buf.WriteRune(ch)
+			} else {
+				escaped = true
+			}
+		case leftBracket, leftSquareBracket:
+			buf.WriteRune(ch)
+			grouped = true
+		case rightBracket, rightSquareBracket:
+			buf.WriteRune(ch)
+			grouped = false
+		default:
+			buf.WriteRune(ch)
+		}
+	}
+
+	last := strings.TrimSpace(buf.String())
+	// ignore last empty string
+	if len(last) > 0 {
+		segments = append(segments, last)
+	}
+
+	return segments
+}
+
 func reprOfValue(val reflect.Value) string {
 	switch vt := val.Interface().(type) {
 	case bool:

+ 76 - 0
core/mapping/utils_test.go

@@ -90,6 +90,82 @@ func TestParseKeyAndOptionWithTagAndOption(t *testing.T) {
 	assert.True(t, options.FromString)
 }
 
+func TestParseSegments(t *testing.T) {
+	tests := []struct {
+		input  string
+		expect []string
+	}{
+		{
+			input:  "",
+			expect: []string{},
+		},
+		{
+			input:  ",",
+			expect: []string{""},
+		},
+		{
+			input:  "foo,",
+			expect: []string{"foo"},
+		},
+		{
+			input: ",foo",
+			// the first empty string cannot be ignored, it's the key.
+			expect: []string{"", "foo"},
+		},
+		{
+			input:  "foo",
+			expect: []string{"foo"},
+		},
+		{
+			input:  "foo,bar",
+			expect: []string{"foo", "bar"},
+		},
+		{
+			input:  "foo,bar,baz",
+			expect: []string{"foo", "bar", "baz"},
+		},
+		{
+			input:  "foo,options=a|b",
+			expect: []string{"foo", "options=a|b"},
+		},
+		{
+			input:  "foo,bar,default=[baz,qux]",
+			expect: []string{"foo", "bar", "default=[baz,qux]"},
+		},
+		{
+			input:  "foo,bar,options=[baz,qux]",
+			expect: []string{"foo", "bar", "options=[baz,qux]"},
+		},
+		{
+			input:  `foo\,bar,options=[baz,qux]`,
+			expect: []string{`foo,bar`, "options=[baz,qux]"},
+		},
+		{
+			input:  `foo,bar,options=\[baz,qux]`,
+			expect: []string{"foo", "bar", "options=[baz", "qux]"},
+		},
+		{
+			input:  `foo,bar,options=[baz\,qux]`,
+			expect: []string{"foo", "bar", `options=[baz\,qux]`},
+		},
+		{
+			input:  `foo\,bar,options=[baz,qux],default=baz`,
+			expect: []string{`foo,bar`, "options=[baz,qux]", "default=baz"},
+		},
+		{
+			input:  `foo\,bar,options=[baz,qux, quux],default=[qux, baz]`,
+			expect: []string{`foo,bar`, "options=[baz,qux, quux]", "default=[qux, baz]"},
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.input, func(t *testing.T) {
+			assert.ElementsMatch(t, test.expect, parseSegments(test.input))
+		})
+	}
+}
+
 func TestValidatePtrWithNonPtr(t *testing.T) {
 	var foo string
 	rve := reflect.ValueOf(foo)