Selaa lähdekoodia

fix: handle the scenarios that content-length is invalid (#2313)

Kevin Wan 2 vuotta sitten
vanhempi
sitoutus
5d00dfb962
2 muutettua tiedostoa jossa 112 lisäystä ja 4 poistoa
  1. 17 4
      rest/httpc/responses.go
  2. 95 0
      rest/httpc/responses_test.go

+ 17 - 4
rest/httpc/responses.go

@@ -1,6 +1,8 @@
 package httpc
 
 import (
+	"bytes"
+	"io"
 	"net/http"
 	"strings"
 
@@ -27,13 +29,24 @@ func ParseHeaders(resp *http.Response, val interface{}) error {
 func ParseJsonBody(resp *http.Response, val interface{}) error {
 	defer resp.Body.Close()
 
-	if withJsonBody(resp) {
-		return mapping.UnmarshalJsonReader(resp.Body, val)
+	if isContentTypeJson(resp) {
+		if resp.ContentLength > 0 {
+			return mapping.UnmarshalJsonReader(resp.Body, val)
+		}
+
+		var buf bytes.Buffer
+		if _, err := io.Copy(&buf, resp.Body); err != nil {
+			return err
+		}
+
+		if buf.Len() > 0 {
+			return mapping.UnmarshalJsonReader(&buf, val)
+		}
 	}
 
 	return mapping.UnmarshalJsonMap(nil, val)
 }
 
-func withJsonBody(r *http.Response) bool {
-	return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
+func isContentTypeJson(r *http.Response) bool {
+	return strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
 }

+ 95 - 0
rest/httpc/responses_test.go

@@ -1,6 +1,7 @@
 package httpc
 
 import (
+	"errors"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -83,3 +84,97 @@ func TestParseWithZeroValue(t *testing.T) {
 	assert.Equal(t, 0, val.Foo)
 	assert.Equal(t, 0, val.Bar)
 }
+
+func TestParseWithNegativeContentLength(t *testing.T) {
+	var val struct {
+		Bar int `json:"bar"`
+	}
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(header.ContentType, header.JsonContentType)
+		w.Write([]byte(`{"bar":0}`))
+	}))
+	defer svr.Close()
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+
+	tests := []struct {
+		name   string
+		length int64
+	}{
+		{
+			name:   "negative",
+			length: -1,
+		},
+		{
+			name:   "zero",
+			length: 0,
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			resp, err := DoRequest(req)
+			resp.ContentLength = test.length
+			assert.Nil(t, err)
+			assert.Nil(t, Parse(resp, &val))
+			assert.Equal(t, 0, val.Bar)
+		})
+	}
+}
+
+func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
+	var val struct{}
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(header.ContentType, header.JsonContentType)
+	}))
+	defer svr.Close()
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+
+	tests := []struct {
+		name   string
+		length int64
+	}{
+		{
+			name:   "negative",
+			length: -1,
+		},
+		{
+			name:   "zero",
+			length: 0,
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			resp, err := DoRequest(req)
+			resp.ContentLength = test.length
+			assert.Nil(t, err)
+			assert.Nil(t, Parse(resp, &val))
+		})
+	}
+}
+
+func TestParseJsonBody_BodyError(t *testing.T) {
+	var val struct{}
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(header.ContentType, header.JsonContentType)
+	}))
+	defer svr.Close()
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+
+	resp, err := DoRequest(req)
+	resp.ContentLength = -1
+	resp.Body = mockedReader{}
+	assert.Nil(t, err)
+	assert.NotNil(t, Parse(resp, &val))
+}
+
+type mockedReader struct{}
+
+func (m mockedReader) Close() error {
+	return nil
+}
+
+func (m mockedReader) Read(p []byte) (n int, err error) {
+	return 0, errors.New("dummy")
+}