Selaa lähdekoodia

feat: support **struct in mapping (#2784)

* feat: support **struct in mapping

* chore: fix test failure
Kevin Wan 2 vuotta sitten
vanhempi
sitoutus
4d7fa08b0b

+ 13 - 6
core/mapping/unmarshaler.go

@@ -77,7 +77,7 @@ func (u *Unmarshaler) Unmarshal(i interface{}, v interface{}) error {
 		return errValueNotSettable
 	}
 
-	elemType := valueType.Elem()
+	elemType := Deref(valueType)
 	switch iv := i.(type) {
 	case map[string]interface{}:
 		if elemType.Kind() != reflect.Struct {
@@ -818,15 +818,22 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
 		return err
 	}
 
-	rte := reflect.TypeOf(v).Elem()
-	if rte.Kind() != reflect.Struct {
+	valueType := reflect.TypeOf(v)
+	baseType := Deref(valueType)
+	if baseType.Kind() != reflect.Struct {
 		return errValueNotStruct
 	}
 
-	rve := rv.Elem()
-	numFields := rte.NumField()
+	valElem := rv.Elem()
+	if valElem.Kind() == reflect.Ptr {
+		target := reflect.New(baseType).Elem()
+		SetValue(valueType.Elem(), valElem, target)
+		valElem = target
+	}
+
+	numFields := baseType.NumField()
 	for i := 0; i < numFields; i++ {
-		if err := u.processField(rte.Field(i), rve.Field(i), m, fullName); err != nil {
+		if err := u.processField(baseType.Field(i), valElem.Field(i), m, fullName); err != nil {
 			return err
 		}
 	}

+ 83 - 25
core/mapping/unmarshaler_test.go

@@ -3,6 +3,7 @@ package mapping
 import (
 	"encoding/json"
 	"fmt"
+	"os"
 	"strconv"
 	"strings"
 	"testing"
@@ -3388,7 +3389,8 @@ func TestUnmarshal_EnvString(t *testing.T) {
 		envName = "TEST_NAME_STRING"
 		envVal  = "this is a name"
 	)
-	t.Setenv(envName, envVal)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3405,7 +3407,8 @@ func TestUnmarshal_EnvStringOverwrite(t *testing.T) {
 		envName = "TEST_NAME_STRING"
 		envVal  = "this is a name"
 	)
-	t.Setenv(envName, envVal)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(map[string]interface{}{
@@ -3420,8 +3423,12 @@ func TestUnmarshal_EnvInt(t *testing.T) {
 		Age int `key:"age,env=TEST_NAME_INT"`
 	}
 
-	const envName = "TEST_NAME_INT"
-	t.Setenv(envName, "123")
+	const (
+		envName = "TEST_NAME_INT"
+		envVal  = "123"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3434,8 +3441,12 @@ func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
 		Age int `key:"age,env=TEST_NAME_INT"`
 	}
 
-	const envName = "TEST_NAME_INT"
-	t.Setenv(envName, "123")
+	const (
+		envName = "TEST_NAME_INT"
+		envVal  = "123"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(map[string]interface{}{
@@ -3450,8 +3461,12 @@ func TestUnmarshal_EnvFloat(t *testing.T) {
 		Age float32 `key:"name,env=TEST_NAME_FLOAT"`
 	}
 
-	const envName = "TEST_NAME_FLOAT"
-	t.Setenv(envName, "123.45")
+	const (
+		envName = "TEST_NAME_FLOAT"
+		envVal  = "123.45"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3464,8 +3479,12 @@ func TestUnmarshal_EnvFloatOverwrite(t *testing.T) {
 		Age float32 `key:"age,env=TEST_NAME_FLOAT"`
 	}
 
-	const envName = "TEST_NAME_FLOAT"
-	t.Setenv(envName, "123.45")
+	const (
+		envName = "TEST_NAME_FLOAT"
+		envVal  = "123.45"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(map[string]interface{}{
@@ -3480,8 +3499,12 @@ func TestUnmarshal_EnvBoolTrue(t *testing.T) {
 		Enable bool `key:"enable,env=TEST_NAME_BOOL_TRUE"`
 	}
 
-	const envName = "TEST_NAME_BOOL_TRUE"
-	t.Setenv(envName, "true")
+	const (
+		envName = "TEST_NAME_BOOL_TRUE"
+		envVal  = "true"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3494,8 +3517,12 @@ func TestUnmarshal_EnvBoolFalse(t *testing.T) {
 		Enable bool `key:"enable,env=TEST_NAME_BOOL_FALSE"`
 	}
 
-	const envName = "TEST_NAME_BOOL_FALSE"
-	t.Setenv(envName, "false")
+	const (
+		envName = "TEST_NAME_BOOL_FALSE"
+		envVal  = "false"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3508,8 +3535,12 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
 		Enable bool `key:"enable,env=TEST_NAME_BOOL_BAD"`
 	}
 
-	const envName = "TEST_NAME_BOOL_BAD"
-	t.Setenv(envName, "bad")
+	const (
+		envName = "TEST_NAME_BOOL_BAD"
+		envVal  = "bad"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3520,8 +3551,12 @@ func TestUnmarshal_EnvDuration(t *testing.T) {
 		Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
 	}
 
-	const envName = "TEST_NAME_DURATION"
-	t.Setenv(envName, "1s")
+	const (
+		envName = "TEST_NAME_DURATION"
+		envVal  = "1s"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3534,8 +3569,12 @@ func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
 		Duration time.Duration `key:"duration,env=TEST_NAME_BAD_DURATION"`
 	}
 
-	const envName = "TEST_NAME_BAD_DURATION"
-	t.Setenv(envName, "bad")
+	const (
+		envName = "TEST_NAME_BAD_DURATION"
+		envVal  = "bad"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3550,7 +3589,8 @@ func TestUnmarshal_EnvWithOptions(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_MATCH"
 		envVal  = "123"
 	)
-	t.Setenv(envName, envVal)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3567,7 +3607,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueBool(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_BOOL"
 		envVal  = "false"
 	)
-	t.Setenv(envName, envVal)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3582,7 +3623,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueDuration(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_DURATION"
 		envVal  = "4s"
 	)
-	t.Setenv(envName, envVal)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3597,7 +3639,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueNumber(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_AGE"
 		envVal  = "30"
 	)
-	t.Setenv(envName, envVal)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3612,7 +3655,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueString(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_STRING"
 		envVal  = "this is a name"
 	)
-	t.Setenv(envName, envVal)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -4115,6 +4159,20 @@ func TestUnmarshalNestedPtr(t *testing.T) {
 	}
 }
 
+func TestUnmarshalStructPtrOfPtr(t *testing.T) {
+	type inner struct {
+		Int int `key:"int"`
+	}
+	m := map[string]interface{}{
+		"int": 1,
+	}
+
+	in := new(inner)
+	if assert.NoError(t, UnmarshalKey(m, &in)) {
+		assert.Equal(t, 1, in.Int)
+	}
+}
+
 func BenchmarkDefaultValue(b *testing.B) {
 	for i := 0; i < b.N; i++ {
 		var a struct {

+ 2 - 2
rest/engine.go

@@ -118,7 +118,7 @@ func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route
 	chn := chain.New()
 
 	if ng.conf.Middlewares.Trace {
-		chn = chn.Append(handler.TracingHandler(ng.conf.Name,
+		chn = chn.Append(handler.TraceHandler(ng.conf.Name,
 			route.Path,
 			handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)))
 	}
@@ -204,7 +204,7 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
 func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		chn := chain.New(
-			handler.TracingHandler(ng.conf.Name,
+			handler.TraceHandler(ng.conf.Name,
 				"",
 				handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)),
 			ng.getLogHandler(),

+ 78 - 0
rest/handler/tracehandler.go

@@ -0,0 +1,78 @@
+package handler
+
+import (
+	"net/http"
+
+	"github.com/zeromicro/go-zero/core/collection"
+	"github.com/zeromicro/go-zero/core/trace"
+	"github.com/zeromicro/go-zero/rest/internal/response"
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/propagation"
+	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
+	oteltrace "go.opentelemetry.io/otel/trace"
+)
+
+type (
+	// TraceOption defines the method to customize an traceOptions.
+	TraceOption func(options *traceOptions)
+
+	// traceOptions is TraceHandler options.
+	traceOptions struct {
+		traceIgnorePaths []string
+	}
+)
+
+// TraceHandler return a middleware that process the opentelemetry.
+func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handler) http.Handler {
+	var options traceOptions
+	for _, opt := range opts {
+		opt(&options)
+	}
+
+	ignorePaths := collection.NewSet()
+	ignorePaths.AddStr(options.traceIgnorePaths...)
+
+	return func(next http.Handler) http.Handler {
+		tracer := otel.Tracer(trace.TraceName)
+		propagator := otel.GetTextMapPropagator()
+
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			spanName := path
+			if len(spanName) == 0 {
+				spanName = r.URL.Path
+			}
+
+			if ignorePaths.Contains(spanName) {
+				next.ServeHTTP(w, r)
+				return
+			}
+
+			ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
+			spanCtx, span := tracer.Start(
+				ctx,
+				spanName,
+				oteltrace.WithSpanKind(oteltrace.SpanKindServer),
+				oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
+					serviceName, spanName, r)...),
+			)
+			defer span.End()
+
+			// convenient for tracking error messages
+			propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
+
+			trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
+			next.ServeHTTP(trw, r.WithContext(spanCtx))
+
+			span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
+			span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(
+				trw.Code, oteltrace.SpanKindServer))
+		})
+	}
+}
+
+// WithTraceIgnorePaths specifies the traceIgnorePaths option for TraceHandler.
+func WithTraceIgnorePaths(traceIgnorePaths []string) TraceOption {
+	return func(options *traceOptions) {
+		options.traceIgnorePaths = append(options.traceIgnorePaths, traceIgnorePaths...)
+	}
+}

+ 3 - 3
rest/handler/tracinghandler_test.go → rest/handler/tracehandler_test.go

@@ -27,7 +27,7 @@ func TestOtelHandler(t *testing.T) {
 
 	for _, test := range []string{"", "bar"} {
 		t.Run(test, func(t *testing.T) {
-			h := chain.New(TracingHandler("foo", test)).Then(
+			h := chain.New(TraceHandler("foo", test)).Then(
 				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 					span := trace.SpanFromContext(r.Context())
 					assert.True(t, span.SpanContext().IsValid())
@@ -65,7 +65,7 @@ func TestDontTracingSpan(t *testing.T) {
 
 	for _, test := range []string{"", "bar", "foo"} {
 		t.Run(test, func(t *testing.T) {
-			h := chain.New(TracingHandler("foo", test, WithTraceIgnorePaths([]string{"bar"}))).Then(
+			h := chain.New(TraceHandler("foo", test, WithTraceIgnorePaths([]string{"bar"}))).Then(
 				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 					span := trace.SpanFromContext(r.Context())
 					spanCtx := span.SpanContext()
@@ -110,7 +110,7 @@ func TestTraceResponseWriter(t *testing.T) {
 
 	for _, test := range []int{0, 200, 300, 400, 401, 500, 503} {
 		t.Run(strconv.Itoa(test), func(t *testing.T) {
-			h := chain.New(TracingHandler("foo", "bar")).Then(
+			h := chain.New(TraceHandler("foo", "bar")).Then(
 				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 					span := trace.SpanFromContext(r.Context())
 					spanCtx := span.SpanContext()

+ 0 - 80
rest/handler/tracinghandler.go

@@ -1,80 +0,0 @@
-package handler
-
-import (
-	"net/http"
-
-	"github.com/zeromicro/go-zero/core/collection"
-	"github.com/zeromicro/go-zero/core/trace"
-	"github.com/zeromicro/go-zero/rest/internal/response"
-	"go.opentelemetry.io/otel"
-	"go.opentelemetry.io/otel/propagation"
-	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
-	oteltrace "go.opentelemetry.io/otel/trace"
-)
-
-type (
-	// TracingOption defines the method to customize an tracingOptions.
-	TracingOption func(options *tracingOptions)
-
-	// tracingOptions is TracingHandler options.
-	tracingOptions struct {
-		traceIgnorePaths []string
-	}
-)
-
-// TracingHandler return a middleware that process the opentelemetry.
-func TracingHandler(serviceName, path string, opts ...TracingOption) func(http.Handler) http.Handler {
-	var tracingOpts tracingOptions
-	for _, opt := range opts {
-		opt(&tracingOpts)
-	}
-
-	ignorePaths := collection.NewSet()
-	ignorePaths.AddStr(tracingOpts.traceIgnorePaths...)
-	traceHandler := func(checkIgnore bool) func(http.Handler) http.Handler {
-		return func(next http.Handler) http.Handler {
-			tracer := otel.Tracer(trace.TraceName)
-			propagator := otel.GetTextMapPropagator()
-
-			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				spanName := path
-				if len(spanName) == 0 {
-					spanName = r.URL.Path
-				}
-
-				if checkIgnore && ignorePaths.Contains(spanName) {
-					next.ServeHTTP(w, r)
-					return
-				}
-
-				ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
-				spanCtx, span := tracer.Start(
-					ctx,
-					spanName,
-					oteltrace.WithSpanKind(oteltrace.SpanKindServer),
-					oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
-						serviceName, spanName, r)...),
-				)
-				defer span.End()
-
-				// convenient for tracking error messages
-				propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
-
-				trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
-				next.ServeHTTP(trw, r.WithContext(spanCtx))
-
-				span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
-				span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(trw.Code, oteltrace.SpanKindServer))
-			})
-		}
-	}
-	checkIgnore := ignorePaths.Count() > 0
-	return traceHandler(checkIgnore)
-}
-
-// WithTraceIgnorePaths specifies the traceIgnorePaths option for TracingHandler.
-func WithTraceIgnorePaths(traceIgnorePaths []string) TracingOption {
-	return func(options *tracingOptions) {
-		options.traceIgnorePaths = append(options.traceIgnorePaths, traceIgnorePaths...)
-	}
-}