浏览代码

feat: support ptr of ptr of ... in mapping (#2779)

* feat: support ptr of ptr of ... in mapping

* feat: support ptr of ptr of time.Duration in mapping

* feat: support ptr of ptr of json.Number in mapping

* chore: improve setting in mapping

* feat: support ptr of ptr encoding.TextUnmarshaler in mapping

* chore: add more tests

* fix: string ptr

* chore: update tests
Kevin Wan 2 年之前
父节点
当前提交
367afb544c

+ 31 - 43
core/mapping/unmarshaler.go

@@ -148,7 +148,6 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
 	}
 
 	baseType := fieldType.Elem()
-	baseKind := baseType.Kind()
 	dereffedBaseType := Deref(baseType)
 	dereffedBaseKind := dereffedBaseType.Kind()
 	refValue := reflect.ValueOf(mapValue)
@@ -177,11 +176,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
 				return err
 			}
 
-			if baseKind == reflect.Ptr {
-				conv.Index(i).Set(target)
-			} else {
-				conv.Index(i).Set(target.Elem())
-			}
+			SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
 		case reflect.Slice:
 			if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue); err != nil {
 				return err
@@ -235,9 +230,9 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
 	ithVal := slice.Index(index)
 	switch v := value.(type) {
 	case fmt.Stringer:
-		return setValue(baseKind, ithVal, v.String())
+		return setValueFromString(baseKind, ithVal, v.String())
 	case string:
-		return setValue(baseKind, ithVal, v)
+		return setValueFromString(baseKind, ithVal, v)
 	case map[string]interface{}:
 		return u.fillMap(ithVal.Type(), ithVal, value)
 	default:
@@ -251,7 +246,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
 
 			target := reflect.New(baseType).Elem()
 			target.Set(reflect.ValueOf(value))
-			ithVal.Set(target.Addr())
+			SetValue(ithVal.Type(), ithVal, target)
 			return nil
 		}
 
@@ -295,7 +290,6 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
 
 	refValue := reflect.ValueOf(mapValue)
 	targetValue := reflect.MakeMapWithSize(mapType, refValue.Len())
-	fieldElemKind := elemType.Kind()
 	dereffedElemType := Deref(elemType)
 	dereffedElemKind := dereffedElemType.Kind()
 
@@ -322,11 +316,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
 				return emptyValue, err
 			}
 
-			if fieldElemKind == reflect.Ptr {
-				targetValue.SetMapIndex(key, target)
-			} else {
-				targetValue.SetMapIndex(key, target.Elem())
-			}
+			SetMapIndexValue(elemType, targetValue, key, target.Elem())
 		case reflect.Map:
 			keythMap, ok := keythData.(map[string]interface{})
 			if !ok {
@@ -355,7 +345,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
 				targetValue.SetMapIndex(key, reflect.ValueOf(v))
 			case json.Number:
 				target := reflect.New(dereffedElemType)
-				if err := setValue(dereffedElemKind, target.Elem(), v.String()); err != nil {
+				if err := setValueFromString(dereffedElemKind, target.Elem(), v.String()); err != nil {
 					return emptyValue, err
 				}
 
@@ -519,7 +509,7 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
 	case valueKind == reflect.String && typeKind == reflect.Slice:
 		return u.fillSliceFromString(fieldType, value, mapValue)
 	case valueKind == reflect.String && derefedFieldType == durationType:
-		return fillDurationValue(fieldType.Kind(), value, mapValue.(string))
+		return fillDurationValue(fieldType, value, mapValue.(string))
 	default:
 		return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
 	}
@@ -555,8 +545,8 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
 
 func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value,
 	v json.Number, opts *fieldOptionsWithContext, fullName string) error {
-	fieldKind := fieldType.Kind()
-	typeKind := Deref(fieldType).Kind()
+	baseType := Deref(fieldType)
+	typeKind := baseType.Kind()
 
 	if err := validateJsonNumberRange(v, opts); err != nil {
 		return err
@@ -566,9 +556,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
 		return err
 	}
 
-	if fieldKind == reflect.Ptr {
-		value = value.Elem()
-	}
+	target := reflect.New(Deref(fieldType)).Elem()
 
 	switch typeKind {
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -577,7 +565,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
 			return err
 		}
 
-		value.SetInt(iValue)
+		target.SetInt(iValue)
 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 		iValue, err := v.Int64()
 		if err != nil {
@@ -588,18 +576,20 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
 			return fmt.Errorf("unmarshal %q with bad value %q", fullName, v.String())
 		}
 
-		value.SetUint(uint64(iValue))
+		target.SetUint(uint64(iValue))
 	case reflect.Float32, reflect.Float64:
 		fValue, err := v.Float64()
 		if err != nil {
 			return err
 		}
 
-		value.SetFloat(fValue)
+		target.SetFloat(fValue)
 	default:
 		return newTypeMismatchError(fullName)
 	}
 
+	SetValue(fieldType, value, target)
+
 	return nil
 }
 
@@ -612,7 +602,7 @@ func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.V
 			return err
 		}
 
-		value.Set(target.Addr())
+		SetValue(fieldType, value, target)
 	} else if err := u.unmarshalWithFullName(m, value.Addr().Interface(), fullName); err != nil {
 		return err
 	}
@@ -626,7 +616,13 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value
 	var ok bool
 
 	if fieldType.Kind() == reflect.Ptr {
-		tval, ok = value.Interface().(encoding.TextUnmarshaler)
+		if value.Elem().Kind() == reflect.Ptr {
+			target := reflect.New(Deref(fieldType))
+			SetValue(fieldType.Elem(), value, target)
+			tval, ok = target.Interface().(encoding.TextUnmarshaler)
+		} else {
+			tval, ok = value.Interface().(encoding.TextUnmarshaler)
+		}
 	} else {
 		tval, ok = value.Addr().Interface().(encoding.TextUnmarshaler)
 	}
@@ -659,7 +655,7 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
 		value.SetBool(val)
 		return nil
 	case durationType.Kind():
-		if err := fillDurationValue(fieldKind, value, envVal); err != nil {
+		if err := fillDurationValue(fieldType, value, envVal); err != nil {
 			return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
 		}
 
@@ -773,19 +769,15 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
 	derefedType := Deref(fieldType)
 	fieldKind := derefedType.Kind()
 	if defaultValue, ok := opts.getDefault(); ok {
-		if fieldType.Kind() == reflect.Ptr {
-			maybeNewValue(fieldType, value)
-			value = value.Elem()
-		}
 		if derefedType == durationType {
-			return fillDurationValue(fieldKind, value, defaultValue)
+			return fillDurationValue(fieldType, value, defaultValue)
 		}
 
 		switch fieldKind {
 		case reflect.Array, reflect.Slice:
 			return u.fillSliceWithDefault(derefedType, value, defaultValue)
 		default:
-			return setValue(fieldKind, value, defaultValue)
+			return setValueFromString(fieldKind, value, defaultValue)
 		}
 	}
 
@@ -870,17 +862,13 @@ func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithP
 	}
 }
 
-func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error {
+func fillDurationValue(fieldType reflect.Type, value reflect.Value, dur string) error {
 	d, err := time.ParseDuration(dur)
 	if err != nil {
 		return err
 	}
 
-	if fieldKind == reflect.Ptr {
-		value.Elem().Set(reflect.ValueOf(d))
-	} else {
-		value.Set(reflect.ValueOf(d))
-	}
+	SetValue(fieldType, value, reflect.ValueOf(d))
 
 	return nil
 }
@@ -896,7 +884,7 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue interfa
 		target := reflect.New(baseType).Elem()
 		switch mapValue.(type) {
 		case string, json.Number:
-			value.Set(target.Addr())
+			SetValue(fieldType, value, target)
 			value = target
 		}
 	}
@@ -908,7 +896,7 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue interfa
 		if err := validateJsonNumberRange(v, opts); err != nil {
 			return err
 		}
-		return setValue(baseType.Kind(), value, v.String())
+		return setValueFromString(baseType.Kind(), value, v.String())
 	default:
 		return newTypeMismatchError(fullName)
 	}
@@ -928,7 +916,7 @@ func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue inte
 		baseType := Deref(fieldType)
 		target := reflect.New(baseType).Elem()
 		setSameKindValue(baseType, target, mapValue)
-		value.Set(target.Addr())
+		SetValue(fieldType, value, target)
 	} else {
 		setSameKindValue(fieldType, value, mapValue)
 	}

文件差异内容过多而无法显示
+ 368 - 223
core/mapping/unmarshaler_test.go


+ 32 - 5
core/mapping/utils.go

@@ -56,7 +56,7 @@ type (
 
 // Deref dereferences a type, if pointer type, returns its element type.
 func Deref(t reflect.Type) reflect.Type {
-	if t.Kind() == reflect.Ptr {
+	for t.Kind() == reflect.Ptr {
 		t = t.Elem()
 	}
 
@@ -68,6 +68,16 @@ func Repr(v interface{}) string {
 	return lang.Repr(v)
 }
 
+// SetValue sets target to value, pointers are processed automatically.
+func SetValue(tp reflect.Type, value, target reflect.Value) {
+	value.Set(convertTypeOfPtr(tp, target))
+}
+
+// SetMapIndexValue sets target to value at key position, pointers are processed automatically.
+func SetMapIndexValue(tp reflect.Type, value, key, target reflect.Value) {
+	value.SetMapIndex(key, convertTypeOfPtr(tp, target))
+}
+
 // ValidatePtr validates v if it's a valid pointer.
 func ValidatePtr(v *reflect.Value) error {
 	// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
@@ -79,7 +89,7 @@ func ValidatePtr(v *reflect.Value) error {
 	return nil
 }
 
-func convertType(kind reflect.Kind, str string) (interface{}, error) {
+func convertTypeFromString(kind reflect.Kind, str string) (interface{}, error) {
 	switch kind {
 	case reflect.Bool:
 		switch strings.ToLower(str) {
@@ -118,6 +128,23 @@ func convertType(kind reflect.Kind, str string) (interface{}, error) {
 	}
 }
 
+func convertTypeOfPtr(tp reflect.Type, target reflect.Value) reflect.Value {
+	// keep the original value is a pointer
+	if tp.Kind() == reflect.Ptr && target.CanAddr() {
+		tp = tp.Elem()
+		target = target.Addr()
+	}
+
+	for tp.Kind() == reflect.Ptr {
+		p := reflect.New(target.Type())
+		p.Elem().Set(target)
+		target = p
+		tp = tp.Elem()
+	}
+
+	return target
+}
+
 func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
 	segments := parseSegments(value)
 	key := strings.TrimSpace(segments[0])
@@ -476,13 +503,13 @@ func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interfac
 	return nil
 }
 
-func setValue(kind reflect.Kind, value reflect.Value, str string) error {
+func setValueFromString(kind reflect.Kind, value reflect.Value, str string) error {
 	if !value.CanSet() {
 		return errValueNotSettable
 	}
 
 	value = ensureValue(value)
-	v, err := convertType(kind, str)
+	v, err := convertTypeFromString(kind, str)
 	if err != nil {
 		return err
 	}
@@ -555,7 +582,7 @@ func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opt
 		return errValueNotSettable
 	}
 
-	v, err := convertType(kind, str)
+	v, err := convertTypeFromString(kind, str)
 	if err != nil {
 		return err
 	}

+ 2 - 2
core/mapping/utils_test.go

@@ -237,7 +237,7 @@ func TestValidatePtrWithZeroValue(t *testing.T) {
 
 func TestSetValueNotSettable(t *testing.T) {
 	var i int
-	assert.NotNil(t, setValue(reflect.Int, reflect.ValueOf(i), "1"))
+	assert.NotNil(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
 }
 
 func TestParseKeyAndOptionsErrors(t *testing.T) {
@@ -290,7 +290,7 @@ func TestSetValueFormatErrors(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.kind.String(), func(t *testing.T) {
-			err := setValue(test.kind, test.target, test.value)
+			err := setValueFromString(test.kind, test.target, test.value)
 			assert.NotEqual(t, errValueNotSettable, err)
 			assert.NotNil(t, err)
 		})

+ 2 - 1
core/stores/mon/trace.go

@@ -14,12 +14,13 @@ import (
 var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
 
 func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
-	tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
+	tracer := otel.Tracer(trace.TraceName)
 	ctx, span := tracer.Start(ctx,
 		spanName,
 		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
 	)
 	span.SetAttributes(mongoCmdAttributeKey.String(cmd))
+
 	return ctx, span
 }
 

+ 1 - 1
core/stores/redis/hook.go

@@ -25,7 +25,7 @@ const spanName = "redis"
 
 var (
 	startTimeKey          = contextKey("startTime")
-	durationHook          = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)}
+	durationHook          = hook{tracer: otel.Tracer(trace.TraceName)}
 	redisCmdsAttributeKey = attribute.Key("redis.cmds")
 )
 

+ 1 - 1
core/stores/sqlx/trace.go

@@ -14,7 +14,7 @@ import (
 var sqlAttributeKey = attribute.Key("sql.method")
 
 func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
-	tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
+	tracer := otel.Tracer(trace.TraceName)
 	start, span := tracer.Start(ctx,
 		spanName,
 		oteltrace.WithSpanKind(oteltrace.SpanKindClient),

+ 1 - 1
rest/handler/tracinghandler.go

@@ -33,8 +33,8 @@ func TracingHandler(serviceName, path string, opts ...TracingOption) func(http.H
 	ignorePaths.AddStr(tracingOpts.traceIgnorePaths...)
 	traceHandler := func(checkIgnore bool) func(http.Handler) http.Handler {
 		return func(next http.Handler) http.Handler {
+			tracer := otel.Tracer(trace.TraceName)
 			propagator := otel.GetTextMapPropagator()
-			tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
 
 			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 				spanName := path

+ 1 - 1
rest/httpc/requests.go

@@ -156,7 +156,7 @@ func fillPath(u *nurl.URL, val map[string]interface{}) error {
 }
 
 func request(r *http.Request, cli client) (*http.Response, error) {
-	tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
+	tracer := otel.Tracer(trace.TraceName)
 	propagator := otel.GetTextMapPropagator()
 
 	spanName := r.URL.Path

部分文件因为文件数量过多而无法显示