Browse Source

fix http header binding failure bug #885 (#887)

voidint 3 năm trước cách đây
mục cha
commit
28a7c9d38f
3 tập tin đã thay đổi với 30 bổ sung8 xóa
  1. 13 3
      core/mapping/unmarshaler.go
  2. 2 2
      rest/httpx/requests.go
  3. 15 3
      rest/httpx/requests_test.go

+ 13 - 3
core/mapping/unmarshaler.go

@@ -43,7 +43,8 @@ type (
 	UnmarshalOption func(*unmarshalOptions)
 
 	unmarshalOptions struct {
-		fromString bool
+		fromString   bool
+		canonicalKey func(key string) string
 	}
 
 	keyCache map[string][]string
@@ -321,9 +322,12 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
 	if err != nil {
 		return err
 	}
-
+	k := key
+	if u.opts.canonicalKey != nil {
+		k = u.opts.canonicalKey(key)
+	}
 	fullName = join(fullName, key)
-	mapValue, hasValue := getValue(m, key)
+	mapValue, hasValue := getValue(m, k)
 	if hasValue {
 		return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName)
 	}
@@ -621,6 +625,12 @@ func WithStringValues() UnmarshalOption {
 	}
 }
 
+func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
+	return func(opt *unmarshalOptions) {
+		opt.canonicalKey = f
+	}
+}
+
 func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error {
 	d, err := time.ParseDuration(dur)
 	if err != nil {

+ 2 - 2
rest/httpx/requests.go

@@ -3,6 +3,7 @@ package httpx
 import (
 	"io"
 	"net/http"
+	"net/textproto"
 	"strings"
 
 	"github.com/tal-tech/go-zero/core/mapping"
@@ -23,7 +24,7 @@ const (
 var (
 	formUnmarshaler   = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
 	pathUnmarshaler   = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
-	headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues())
+	headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(), mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
 )
 
 // Parse parses the request.
@@ -47,7 +48,6 @@ func Parse(r *http.Request, v interface{}) error {
 func ParseHeaders(r *http.Request, v interface{}) error {
 	m := map[string]interface{}{}
 	for k, v := range r.Header {
-		k = strings.ToLower(k)
 		if len(v) == 1 {
 			m[k] = v[0]
 		} else {

+ 15 - 3
rest/httpx/requests_test.go

@@ -203,10 +203,16 @@ func BenchmarkParseAuto(b *testing.B) {
 }
 
 func TestParseHeaders(t *testing.T) {
+	type AnonymousStruct struct {
+		XRealIP string `header:"x-real-ip"`
+		Accept  string `header:"Accept,optional"`
+	}
 	v := struct {
-		Name    string   `header:"name"`
-		Percent string   `header:"percent"`
-		Addrs   []string `header:"addrs"`
+		Name          string   `header:"name,optional"`
+		Percent       string   `header:"percent"`
+		Addrs         []string `header:"addrs"`
+		XForwardedFor string   `header:"X-Forwarded-For,optional"`
+		AnonymousStruct
 	}{}
 	request, err := http.NewRequest("POST", "http://hello.com/", nil)
 	if err != nil {
@@ -216,6 +222,9 @@ func TestParseHeaders(t *testing.T) {
 	request.Header.Set("percent", "1")
 	request.Header.Add("addrs", "addr1")
 	request.Header.Add("addrs", "addr2")
+	request.Header.Add("X-Forwarded-For", "10.0.10.11")
+	request.Header.Add("x-real-ip", "10.0.11.10")
+	request.Header.Add("Accept", "application/json")
 	err = ParseHeaders(request, &v)
 	if err != nil {
 		t.Fatal(err)
@@ -223,4 +232,7 @@ func TestParseHeaders(t *testing.T) {
 	assert.Equal(t, "chenquan", v.Name)
 	assert.Equal(t, "1", v.Percent)
 	assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
+	assert.Equal(t, "10.0.10.11", v.XForwardedFor)
+	assert.Equal(t, "10.0.11.10", v.XRealIP)
+	assert.Equal(t, "application/json", v.Accept)
 }