Przeglądaj źródła

feat: support form values in gateway (#2158)

Kevin Wan 2 lat temu
rodzic
commit
b206dd28a3

+ 15 - 5
gateway/requestparser.go

@@ -7,12 +7,14 @@ import (
 
 	"github.com/fullstorydev/grpcurl"
 	"github.com/golang/protobuf/jsonpb"
+	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/pathvar"
 )
 
-func buildJsonRequestParser(v interface{}, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
+func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) (
+	grpcurl.RequestParser, error) {
 	var buf bytes.Buffer
-	if err := json.NewEncoder(&buf).Encode(v); err != nil {
+	if err := json.NewEncoder(&buf).Encode(m); err != nil {
 		return nil, err
 	}
 
@@ -21,12 +23,20 @@ func buildJsonRequestParser(v interface{}, resolver jsonpb.AnyResolver) (grpcurl
 
 func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
 	vars := pathvar.Vars(r)
-	if len(vars) == 0 {
+	params, err := httpx.GetFormValues(r)
+	if err != nil {
+		return nil, err
+	}
+
+	for k, v := range vars {
+		params[k] = v
+	}
+	if len(params) == 0 {
 		return grpcurl.NewJSONRequestParser(r.Body, resolver), nil
 	}
 
 	if r.ContentLength == 0 {
-		return buildJsonRequestParser(vars, resolver)
+		return buildJsonRequestParser(params, resolver)
 	}
 
 	m := make(map[string]interface{})
@@ -34,7 +44,7 @@ func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.Req
 		return nil, err
 	}
 
-	for k, v := range vars {
+	for k, v := range params {
 		m[k] = v
 	}
 

+ 7 - 0
gateway/requestparser_test.go

@@ -46,3 +46,10 @@ func TestNewRequestParserWithVarsWithWrongBody(t *testing.T) {
 	assert.NotNil(t, err)
 	assert.Nil(t, parser)
 }
+
+func TestNewRequestParserWithForm(t *testing.T) {
+	req := httptest.NewRequest("GET", "/val?a=b", nil)
+	parser, err := newRequestParser(req, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, parser)
+}

+ 2 - 15
rest/httpx/requests.go

@@ -49,24 +49,11 @@ func ParseHeaders(r *http.Request, v interface{}) error {
 
 // ParseForm parses the form request.
 func ParseForm(r *http.Request, v interface{}) error {
-	if err := r.ParseForm(); err != nil {
+	params, err := GetFormValues(r)
+	if err != nil {
 		return err
 	}
 
-	if err := r.ParseMultipartForm(maxMemory); err != nil {
-		if err != http.ErrNotMultipart {
-			return err
-		}
-	}
-
-	params := make(map[string]interface{}, len(r.Form))
-	for name := range r.Form {
-		formValue := r.Form.Get(name)
-		if len(formValue) > 0 {
-			params[name] = formValue
-		}
-	}
-
 	return formUnmarshaler.Unmarshal(params, v)
 }
 

+ 23 - 0
rest/httpx/util.go

@@ -4,6 +4,29 @@ import "net/http"
 
 const xForwardedFor = "X-Forwarded-For"
 
+// GetFormValues returns the form values.
+func GetFormValues(r *http.Request) (map[string]interface{}, error) {
+	if err := r.ParseForm(); err != nil {
+		return nil, err
+	}
+
+	if err := r.ParseMultipartForm(maxMemory); err != nil {
+		if err != http.ErrNotMultipart {
+			return nil, err
+		}
+	}
+
+	params := make(map[string]interface{}, len(r.Form))
+	for name := range r.Form {
+		formValue := r.Form.Get(name)
+		if len(formValue) > 0 {
+			params[name] = formValue
+		}
+	}
+
+	return params, nil
+}
+
 // GetRemoteAddr returns the peer address, supports X-Forward-For.
 func GetRemoteAddr(r *http.Request) string {
 	v := r.Header.Get(xForwardedFor)