Răsfoiți Sursa

feat: simplify httpc (#1748)

* feat: simplify httpc

* chore: fix lint errors

* chore: fix log url issue

* chore: fix log url issue

* refactor: handle resp & err in ResponseHandler

* chore: remove unnecessary var names in return clause
Kevin Wan 3 ani în urmă
părinte
comite
78ea0769fd

+ 1 - 1
rest/httpc/internal/interceptor.go

@@ -4,5 +4,5 @@ import "net/http"
 
 type (
 	Interceptor     func(r *http.Request) (*http.Request, ResponseHandler)
-	ResponseHandler func(*http.Response)
+	ResponseHandler func(resp *http.Response, err error)
 )

+ 9 - 3
rest/httpc/internal/loginterceptor.go

@@ -10,15 +10,21 @@ import (
 
 func LogInterceptor(r *http.Request) (*http.Request, ResponseHandler) {
 	start := timex.Now()
-	return r, func(resp *http.Response) {
+	return r, func(resp *http.Response, err error) {
 		duration := timex.Since(start)
+		if err != nil {
+			logger := logx.WithContext(r.Context()).WithDuration(duration)
+			logger.Errorf("[HTTP] %s %s - %v", r.Method, r.URL, err)
+			return
+		}
+
 		var tc propagation.TraceContext
 		ctx := tc.Extract(r.Context(), propagation.HeaderCarrier(resp.Header))
 		logger := logx.WithContext(ctx).WithDuration(duration)
 		if isOkResponse(resp.StatusCode) {
-			logger.Infof("[HTTP] %d - %s %s/%s", resp.StatusCode, r.Method, r.Host, r.RequestURI)
+			logger.Infof("[HTTP] %d - %s %s", resp.StatusCode, r.Method, r.URL)
 		} else {
-			logger.Errorf("[HTTP] %d - %s %s/%s", resp.StatusCode, r.Method, r.Host, r.RequestURI)
+			logger.Errorf("[HTTP] %d - %s %s", resp.StatusCode, r.Method, r.URL)
 		}
 	}
 }

+ 17 - 2
rest/httpc/internal/loginterceptor_test.go

@@ -16,8 +16,8 @@ func TestLogInterceptor(t *testing.T) {
 	assert.Nil(t, err)
 	req, handler := LogInterceptor(req)
 	resp, err := http.DefaultClient.Do(req)
+	handler(resp, err)
 	assert.Nil(t, err)
-	handler(resp)
 	assert.Equal(t, http.StatusOK, resp.StatusCode)
 }
 
@@ -30,7 +30,22 @@ func TestLogInterceptorServerError(t *testing.T) {
 	assert.Nil(t, err)
 	req, handler := LogInterceptor(req)
 	resp, err := http.DefaultClient.Do(req)
+	handler(resp, err)
 	assert.Nil(t, err)
-	handler(resp)
 	assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
 }
+
+func TestLogInterceptorServerClosed(t *testing.T) {
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+	}))
+	defer svr.Close()
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+	svr.Close()
+	req, handler := LogInterceptor(req)
+	resp, err := http.DefaultClient.Do(req)
+	handler(resp, err)
+	assert.NotNil(t, err)
+	assert.Nil(t, resp)
+}

+ 33 - 10
rest/httpc/requests.go

@@ -1,21 +1,44 @@
 package httpc
 
 import (
-	"io"
 	"net/http"
+
+	"github.com/zeromicro/go-zero/rest/httpc/internal"
 )
 
-// Do sends an HTTP request to the service assocated with the given key.
-func Do(key string, r *http.Request) (*http.Response, error) {
-	return NewService(key).Do(r)
+var interceptors = []internal.Interceptor{
+	internal.LogInterceptor,
+}
+
+// DoRequest sends an HTTP request and returns an HTTP response.
+func DoRequest(r *http.Request) (*http.Response, error) {
+	return request(r, defaultClient{})
 }
 
-// Get sends an HTTP GET request to the service assocated with the given key.
-func Get(key, url string) (*http.Response, error) {
-	return NewService(key).Get(url)
+type (
+	client interface {
+		do(r *http.Request) (*http.Response, error)
+	}
+
+	defaultClient struct{}
+)
+
+func (c defaultClient) do(r *http.Request) (*http.Response, error) {
+	return http.DefaultClient.Do(r)
 }
 
-// Post sends an HTTP POST request to the service assocated with the given key.
-func Post(key, url, contentType string, body io.Reader) (*http.Response, error) {
-	return NewService(key).Post(url, contentType, body)
+func request(r *http.Request, cli client) (*http.Response, error) {
+	var respHandlers []internal.ResponseHandler
+	for _, interceptor := range interceptors {
+		var h internal.ResponseHandler
+		r, h = interceptor(r)
+		respHandlers = append(respHandlers, h)
+	}
+
+	resp, err := cli.do(r)
+	for i := len(respHandlers) - 1; i >= 0; i-- {
+		respHandlers[i](resp, err)
+	}
+
+	return resp, err
 }

+ 8 - 7
rest/httpc/requests_test.go

@@ -12,9 +12,9 @@ func TestDo(t *testing.T) {
 	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	}))
 	defer svr.Close()
-	_, err := Get("foo", "tcp://bad request")
-	assert.NotNil(t, err)
-	resp, err := Get("foo", svr.URL)
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+	resp, err := DoRequest(req)
 	assert.Nil(t, err)
 	assert.Equal(t, http.StatusOK, resp.StatusCode)
 }
@@ -22,9 +22,10 @@ func TestDo(t *testing.T) {
 func TestDoNotFound(t *testing.T) {
 	svr := httptest.NewServer(http.NotFoundHandler())
 	defer svr.Close()
-	_, err := Post("foo", "tcp://bad request", "application/json", nil)
-	assert.NotNil(t, err)
-	resp, err := Post("foo", svr.URL, "application/json", nil)
+	req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
+	assert.Nil(t, err)
+	req.Header.Set("Content-Type", "application/json")
+	resp, err := DoRequest(req)
 	assert.Nil(t, err)
 	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
 }
@@ -34,7 +35,7 @@ func TestDoMoved(t *testing.T) {
 	defer svr.Close()
 	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
 	assert.Nil(t, err)
-	_, err = Do("foo", req)
+	_, err = DoRequest(req)
 	// too many redirects
 	assert.NotNil(t, err)
 }

+ 9 - 3
rest/httpc/responses_test.go

@@ -20,7 +20,9 @@ func TestParse(t *testing.T) {
 		w.Write([]byte(`{"name":"kevin","value":100}`))
 	}))
 	defer svr.Close()
-	resp, err := Get("foo", svr.URL)
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+	resp, err := DoRequest(req)
 	assert.Nil(t, err)
 	assert.Nil(t, Parse(resp, &val))
 	assert.Equal(t, "bar", val.Foo)
@@ -37,7 +39,9 @@ func TestParseHeaderError(t *testing.T) {
 		w.Header().Set(contentType, applicationJson)
 	}))
 	defer svr.Close()
-	resp, err := Get("foo", svr.URL)
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+	resp, err := DoRequest(req)
 	assert.Nil(t, err)
 	assert.NotNil(t, Parse(resp, &val))
 }
@@ -51,7 +55,9 @@ func TestParseNoBody(t *testing.T) {
 		w.Header().Set(contentType, applicationJson)
 	}))
 	defer svr.Close()
-	resp, err := Get("foo", svr.URL)
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+	resp, err := DoRequest(req)
 	assert.Nil(t, err)
 	assert.Nil(t, Parse(resp, &val))
 	assert.Equal(t, "bar", val.Foo)

+ 6 - 55
rest/httpc/service.go

@@ -1,30 +1,19 @@
 package httpc
 
 import (
-	"io"
 	"net/http"
 
 	"github.com/zeromicro/go-zero/core/breaker"
-	"github.com/zeromicro/go-zero/core/logx"
-	"github.com/zeromicro/go-zero/rest/httpc/internal"
 )
 
-var interceptors = []internal.Interceptor{
-	internal.LogInterceptor,
-}
-
 type (
 	// Option is used to customize the *http.Client.
 	Option func(r *http.Request) *http.Request
 
 	// Service represents a remote HTTP service.
 	Service interface {
-		// Do sends an HTTP request to the service.
-		Do(r *http.Request) (*http.Response, error)
-		// Get sends an HTTP GET request to the service.
-		Get(url string) (*http.Response, error)
-		// Post sends an HTTP POST request to the service.
-		Post(url, contentType string, body io.Reader) (*http.Response, error)
+		// DoRequest sends a HTTP request to the service.
+		DoRequest(r *http.Request) (*http.Response, error)
 	}
 
 	namedService struct {
@@ -50,50 +39,12 @@ func NewServiceWithClient(name string, cli *http.Client, opts ...Option) Service
 	}
 }
 
-// Do sends an HTTP request to the service.
-func (s namedService) Do(r *http.Request) (resp *http.Response, err error) {
-	var respHandlers []internal.ResponseHandler
-	for _, interceptor := range interceptors {
-		var h internal.ResponseHandler
-		r, h = interceptor(r)
-		respHandlers = append(respHandlers, h)
-	}
-
-	resp, err = s.doRequest(r)
-	if err != nil {
-		logx.Errorf("[HTTP] %s %s/%s - %v", r.Method, r.Host, r.RequestURI, err)
-		return
-	}
-
-	for i := len(respHandlers) - 1; i >= 0; i-- {
-		respHandlers[i](resp)
-	}
-
-	return
-}
-
-// Get sends an HTTP GET request to the service.
-func (s namedService) Get(url string) (*http.Response, error) {
-	r, err := http.NewRequest(http.MethodGet, url, nil)
-	if err != nil {
-		return nil, err
-	}
-
-	return s.Do(r)
-}
-
-// Post sends an HTTP POST request to the service.
-func (s namedService) Post(url, ctype string, body io.Reader) (*http.Response, error) {
-	r, err := http.NewRequest(http.MethodPost, url, body)
-	if err != nil {
-		return nil, err
-	}
-
-	r.Header.Set(contentType, ctype)
-	return s.Do(r)
+// DoRequest sends an HTTP request to the service.
+func (s namedService) DoRequest(r *http.Request) (*http.Response, error) {
+	return request(r, s)
 }
 
-func (s namedService) doRequest(r *http.Request) (resp *http.Response, err error) {
+func (s namedService) do(r *http.Request) (resp *http.Response, err error) {
 	for _, opt := range s.opts {
 		r = opt(r)
 	}

+ 8 - 5
rest/httpc/service_test.go

@@ -14,7 +14,7 @@ func TestNamedService_Do(t *testing.T) {
 	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
 	assert.Nil(t, err)
 	service := NewService("foo")
-	_, err = service.Do(req)
+	_, err = service.DoRequest(req)
 	// too many redirects
 	assert.NotNil(t, err)
 }
@@ -28,7 +28,9 @@ func TestNamedService_Get(t *testing.T) {
 		r.Header.Set("foo", "bar")
 		return r
 	})
-	resp, err := service.Get(svr.URL)
+	req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
+	assert.Nil(t, err)
+	resp, err := service.DoRequest(req)
 	assert.Nil(t, err)
 	assert.Equal(t, http.StatusOK, resp.StatusCode)
 	assert.Equal(t, "bar", resp.Header.Get("foo"))
@@ -38,9 +40,10 @@ func TestNamedService_Post(t *testing.T) {
 	svr := httptest.NewServer(http.NotFoundHandler())
 	defer svr.Close()
 	service := NewService("foo")
-	_, err := service.Post("tcp://bad request", "application/json", nil)
-	assert.NotNil(t, err)
-	resp, err := service.Post(svr.URL, "application/json", nil)
+	req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
+	assert.Nil(t, err)
+	req.Header.Set("Content-Type", "application/json")
+	resp, err := service.DoRequest(req)
 	assert.Nil(t, err)
 	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
 }