Explorar el Código

feat: add httpc.Parse (#1698)

Kevin Wan hace 3 años
padre
commit
c1d9e6a00b

+ 2 - 0
rest/httpc/internal/loginterceptor_test.go

@@ -11,6 +11,7 @@ import (
 func TestLogInterceptor(t *testing.T) {
 	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	}))
+	defer svr.Close()
 	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
 	assert.Nil(t, err)
 	req, handler := LogInterceptor(req)
@@ -24,6 +25,7 @@ func TestLogInterceptorServerError(t *testing.T) {
 	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusInternalServerError)
 	}))
+	defer svr.Close()
 	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
 	assert.Nil(t, err)
 	req, handler := LogInterceptor(req)

+ 3 - 0
rest/httpc/requests_test.go

@@ -11,6 +11,7 @@ import (
 func TestDo(t *testing.T) {
 	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	}))
+	defer svr.Close()
 	_, err := Get("foo", "tcp://bad request")
 	assert.NotNil(t, err)
 	resp, err := Get("foo", svr.URL)
@@ -20,6 +21,7 @@ func TestDo(t *testing.T) {
 
 func TestDoNotFound(t *testing.T) {
 	svr := httptest.NewServer(http.NotFoundHandler())
+	defer svr.Close()
 	_, err := Post("foo", "tcp://bad request", "application/json", nil)
 	assert.NotNil(t, err)
 	resp, err := Post("foo", svr.URL, "application/json", nil)
@@ -29,6 +31,7 @@ func TestDoNotFound(t *testing.T) {
 
 func TestDoMoved(t *testing.T) {
 	svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently))
+	defer svr.Close()
 	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
 	assert.Nil(t, err)
 	_, err = Do("foo", req)

+ 33 - 0
rest/httpc/responses.go

@@ -0,0 +1,33 @@
+package httpc
+
+import (
+	"net/http"
+	"strings"
+
+	"github.com/zeromicro/go-zero/core/mapping"
+	"github.com/zeromicro/go-zero/rest/internal/encoding"
+)
+
+func Parse(resp *http.Response, val interface{}) error {
+	if err := ParseHeaders(resp, val); err != nil {
+		return err
+	}
+
+	return ParseJsonBody(resp, val)
+}
+
+func ParseHeaders(resp *http.Response, val interface{}) error {
+	return encoding.ParseHeaders(resp.Header, val)
+}
+
+func ParseJsonBody(resp *http.Response, val interface{}) error {
+	if withJsonBody(resp) {
+		return mapping.UnmarshalJsonReader(resp.Body, val)
+	}
+
+	return mapping.UnmarshalJsonMap(nil, val)
+}
+
+func withJsonBody(r *http.Response) bool {
+	return r.ContentLength > 0 && strings.Contains(r.Header.Get(contentType), applicationJson)
+}

+ 58 - 0
rest/httpc/responses_test.go

@@ -0,0 +1,58 @@
+package httpc
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestParse(t *testing.T) {
+	var val struct {
+		Foo   string `header:"foo"`
+		Name  string `json:"name"`
+		Value int    `json:"value"`
+	}
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("foo", "bar")
+		w.Header().Set(contentType, applicationJson)
+		w.Write([]byte(`{"name":"kevin","value":100}`))
+	}))
+	defer svr.Close()
+	resp, err := Get("foo", svr.URL)
+	assert.Nil(t, err)
+	assert.Nil(t, Parse(resp, &val))
+	assert.Equal(t, "bar", val.Foo)
+	assert.Equal(t, "kevin", val.Name)
+	assert.Equal(t, 100, val.Value)
+}
+
+func TestParseHeaderError(t *testing.T) {
+	var val struct {
+		Foo int `header:"foo"`
+	}
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("foo", "bar")
+		w.Header().Set(contentType, applicationJson)
+	}))
+	defer svr.Close()
+	resp, err := Get("foo", svr.URL)
+	assert.Nil(t, err)
+	assert.NotNil(t, Parse(resp, &val))
+}
+
+func TestParseNoBody(t *testing.T) {
+	var val struct {
+		Foo string `header:"foo"`
+	}
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("foo", "bar")
+		w.Header().Set(contentType, applicationJson)
+	}))
+	defer svr.Close()
+	resp, err := Get("foo", svr.URL)
+	assert.Nil(t, err)
+	assert.Nil(t, Parse(resp, &val))
+	assert.Equal(t, "bar", val.Foo)
+}

+ 2 - 5
rest/httpc/service.go

@@ -9,9 +9,6 @@ import (
 	"github.com/zeromicro/go-zero/rest/httpc/internal"
 )
 
-// ContentType means Content-Type.
-const ContentType = "Content-Type"
-
 var interceptors = []internal.Interceptor{
 	internal.LogInterceptor,
 }
@@ -86,13 +83,13 @@ func (s namedService) Get(url string) (*http.Response, error) {
 }
 
 // Post sends an HTTP POST request to the service.
-func (s namedService) Post(url, contentType string, body io.Reader) (*http.Response, error) {
+func (s namedService) Post(url, ctype string, body io.Reader) (*http.Response, error) {
 	r, err := http.NewRequest(http.MethodPost, url, body)
 	if err != nil {
 		return nil, err
 	}
 
-	r.Header.Set(ContentType, contentType)
+	r.Header.Set(contentType, ctype)
 	return s.Do(r)
 }
 

+ 3 - 0
rest/httpc/service_test.go

@@ -10,6 +10,7 @@ import (
 
 func TestNamedService_Do(t *testing.T) {
 	svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently))
+	defer svr.Close()
 	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
 	assert.Nil(t, err)
 	service := NewService("foo")
@@ -22,6 +23,7 @@ func TestNamedService_Get(t *testing.T) {
 	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("foo", r.Header.Get("foo"))
 	}))
+	defer svr.Close()
 	service := NewService("foo", func(r *http.Request) *http.Request {
 		r.Header.Set("foo", "bar")
 		return r
@@ -34,6 +36,7 @@ func TestNamedService_Get(t *testing.T) {
 
 func TestNamedService_Post(t *testing.T) {
 	svr := httptest.NewServer(http.NotFoundHandler())
+	defer svr.Close()
 	service := NewService("foo")
 	_, err := service.Post("tcp://bad request", "application/json", nil)
 	assert.NotNil(t, err)

+ 6 - 0
rest/httpc/vars.go

@@ -0,0 +1,6 @@
+package httpc
+
+const (
+	contentType     = "Content-Type"
+	applicationJson = "application/json"
+)

+ 4 - 16
rest/httpx/requests.go

@@ -3,17 +3,16 @@ package httpx
 import (
 	"io"
 	"net/http"
-	"net/textproto"
 	"strings"
 
 	"github.com/zeromicro/go-zero/core/mapping"
+	"github.com/zeromicro/go-zero/rest/internal/encoding"
 	"github.com/zeromicro/go-zero/rest/pathvar"
 )
 
 const (
 	formKey           = "form"
 	pathKey           = "path"
-	headerKey         = "header"
 	maxMemory         = 32 << 20 // 32MB
 	maxBodyLen        = 8 << 20  // 8MB
 	separator         = ";"
@@ -21,10 +20,8 @@ const (
 )
 
 var (
-	formUnmarshaler   = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
-	pathUnmarshaler   = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
-	headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(),
-		mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
+	formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
+	pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
 )
 
 // Parse parses the request.
@@ -46,16 +43,7 @@ func Parse(r *http.Request, v interface{}) error {
 
 // ParseHeaders parses the headers request.
 func ParseHeaders(r *http.Request, v interface{}) error {
-	m := map[string]interface{}{}
-	for k, v := range r.Header {
-		if len(v) == 1 {
-			m[k] = v[0]
-		} else {
-			m[k] = v
-		}
-	}
-
-	return headerUnmarshaler.Unmarshal(m, v)
+	return encoding.ParseHeaders(r.Header, v)
 }
 
 // ParseForm parses the form request.

+ 27 - 0
rest/internal/encoding/parser.go

@@ -0,0 +1,27 @@
+package encoding
+
+import (
+	"net/http"
+	"net/textproto"
+
+	"github.com/zeromicro/go-zero/core/mapping"
+)
+
+const headerKey = "header"
+
+var headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(),
+	mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
+
+// ParseHeaders parses the headers request.
+func ParseHeaders(header http.Header, v interface{}) error {
+	m := map[string]interface{}{}
+	for k, v := range header {
+		if len(v) == 1 {
+			m[k] = v[0]
+		} else {
+			m[k] = v
+		}
+	}
+
+	return headerUnmarshaler.Unmarshal(m, v)
+}

+ 40 - 0
rest/internal/encoding/parser_test.go

@@ -0,0 +1,40 @@
+package encoding
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestParseHeaders(t *testing.T) {
+	var val struct {
+		Foo string `header:"foo"`
+		Baz int    `header:"baz"`
+		Qux bool   `header:"qux,default=true"`
+	}
+	r := httptest.NewRequest(http.MethodGet, "/any", nil)
+	r.Header.Set("foo", "bar")
+	r.Header.Set("baz", "1")
+	assert.Nil(t, ParseHeaders(r.Header, &val))
+	assert.Equal(t, "bar", val.Foo)
+	assert.Equal(t, 1, val.Baz)
+	assert.True(t, val.Qux)
+}
+
+func TestParseHeadersMulti(t *testing.T) {
+	var val struct {
+		Foo []string `header:"foo"`
+		Baz int      `header:"baz"`
+		Qux bool     `header:"qux,default=true"`
+	}
+	r := httptest.NewRequest(http.MethodGet, "/any", nil)
+	r.Header.Set("foo", "bar")
+	r.Header.Add("foo", "bar1")
+	r.Header.Set("baz", "1")
+	assert.Nil(t, ParseHeaders(r.Header, &val))
+	assert.Equal(t, []string{"bar", "bar1"}, val.Foo)
+	assert.Equal(t, 1, val.Baz)
+	assert.True(t, val.Qux)
+}