瀏覽代碼

refactor: httpc package for easy to use (#1645)

Kevin Wan 3 年之前
父節點
當前提交
f9e6013a6c
共有 4 個文件被更改,包括 24 次插入22 次删除
  1. 6 6
      rest/httpc/requests.go
  2. 1 3
      rest/httpc/requests_test.go
  3. 12 11
      rest/httpc/service.go
  4. 5 2
      rest/httpc/service_test.go

+ 6 - 6
rest/httpc/requests.go

@@ -6,16 +6,16 @@ import (
 )
 
 // Do sends an HTTP request to the service assocated with the given key.
-func Do(key string, r *http.Request, opts ...Option) (*http.Response, error) {
-	return NewService(key, opts...).Do(r)
+func Do(key string, r *http.Request) (*http.Response, error) {
+	return NewService(key).Do(r)
 }
 
 // Get sends an HTTP GET request to the service assocated with the given key.
-func Get(key, url string, opts ...Option) (*http.Response, error) {
-	return NewService(key, opts...).Get(url)
+func Get(key, url string) (*http.Response, error) {
+	return NewService(key).Get(url)
 }
 
 // Post sends an HTTP POST request to the service assocated with the given key.
-func Post(key, url, contentType string, body io.Reader, opts ...Option) (*http.Response, error) {
-	return NewService(key, opts...).Post(url, contentType, body)
+func Post(key, url, contentType string, body io.Reader) (*http.Response, error) {
+	return NewService(key).Post(url, contentType, body)
 }

+ 1 - 3
rest/httpc/requests_test.go

@@ -13,9 +13,7 @@ func TestDo(t *testing.T) {
 	}))
 	_, err := Get("foo", "tcp://bad request")
 	assert.NotNil(t, err)
-	resp, err := Get("foo", svr.URL, func(cli *http.Client) {
-		cli.Transport = http.DefaultTransport
-	})
+	resp, err := Get("foo", svr.URL)
 	assert.Nil(t, err)
 	assert.Equal(t, http.StatusOK, resp.StatusCode)
 }

+ 12 - 11
rest/httpc/service.go

@@ -18,7 +18,7 @@ var interceptors = []internal.Interceptor{
 
 type (
 	// Option is used to customize the *http.Client.
-	Option func(cli *http.Client)
+	Option func(r *http.Request) *http.Request
 
 	// Service represents a remote HTTP service.
 	Service interface {
@@ -33,26 +33,23 @@ type (
 	namedService struct {
 		name string
 		cli  *http.Client
+		opts []Option
 	}
 )
 
 // NewService returns a remote service with the given name.
 // opts are used to customize the *http.Client.
 func NewService(name string, opts ...Option) Service {
-	var cli *http.Client
-
-	if len(opts) == 0 {
-		cli = http.DefaultClient
-	} else {
-		cli = &http.Client{}
-		for _, opt := range opts {
-			opt(cli)
-		}
-	}
+	return NewServiceWithClient(name, http.DefaultClient, opts...)
+}
 
+// NewServiceWithClient returns a remote service with the given name.
+// opts are used to customize the *http.Client.
+func NewServiceWithClient(name string, cli *http.Client, opts ...Option) Service {
 	return namedService{
 		name: name,
 		cli:  cli,
+		opts: opts,
 	}
 }
 
@@ -100,6 +97,10 @@ func (s namedService) Post(url, contentType string, body io.Reader) (*http.Respo
 }
 
 func (s namedService) doRequest(r *http.Request) (resp *http.Response, err error) {
+	for _, opt := range s.opts {
+		r = opt(r)
+	}
+
 	brk := breaker.GetBreaker(s.name)
 	err = brk.DoWithAcceptable(func() error {
 		resp, err = s.cli.Do(r)

+ 5 - 2
rest/httpc/service_test.go

@@ -20,13 +20,16 @@ func TestNamedService_Do(t *testing.T) {
 
 func TestNamedService_Get(t *testing.T) {
 	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("foo", r.Header.Get("foo"))
 	}))
-	service := NewService("foo", func(cli *http.Client) {
-		cli.Transport = http.DefaultTransport
+	service := NewService("foo", func(r *http.Request) *http.Request {
+		r.Header.Set("foo", "bar")
+		return r
 	})
 	resp, err := service.Get(svr.URL)
 	assert.Nil(t, err)
 	assert.Equal(t, http.StatusOK, resp.StatusCode)
+	assert.Equal(t, "bar", resp.Header.Get("foo"))
 }
 
 func TestNamedService_Post(t *testing.T) {