瀏覽代碼

fix #1070 (#1389)

* fix #1070

* test: add more tests
Kevin Wan 3 年之前
父節點
當前提交
62266d8f91
共有 4 個文件被更改,包括 27 次插入9 次删除
  1. 9 6
      rest/handler/timeouthandler.go
  2. 6 2
      rest/httpx/responses.go
  3. 12 0
      rest/httpx/responses_test.go
  4. 0 1
      tools/goctl/api/apigen/gen.go

+ 9 - 6
rest/handler/timeouthandler.go

@@ -13,6 +13,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/tal-tech/go-zero/rest/httpx"
 	"github.com/tal-tech/go-zero/rest/internal"
 )
 
@@ -91,12 +92,14 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		defer tw.mu.Unlock()
 		// there isn't any user-defined middleware before TimoutHandler,
 		// so we can guarantee that cancelation in biz related code won't come here.
-		if errors.Is(ctx.Err(), context.Canceled) {
-			w.WriteHeader(statusClientClosedRequest)
-		} else {
-			w.WriteHeader(http.StatusServiceUnavailable)
-		}
-		io.WriteString(w, h.errorBody())
+		httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
+			if errors.Is(err, context.Canceled) {
+				w.WriteHeader(statusClientClosedRequest)
+			} else {
+				w.WriteHeader(http.StatusServiceUnavailable)
+			}
+			io.WriteString(w, h.errorBody())
+		})
 		tw.timedOut = true
 	}
 }

+ 6 - 2
rest/httpx/responses.go

@@ -14,13 +14,17 @@ var (
 )
 
 // Error writes err into w.
-func Error(w http.ResponseWriter, err error) {
+func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
 	lock.RLock()
 	handler := errorHandler
 	lock.RUnlock()
 
 	if handler == nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
+		if len(fns) > 0 {
+			fns[0](w, err)
+		} else {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+		}
 		return
 	}
 

+ 12 - 0
rest/httpx/responses_test.go

@@ -95,6 +95,18 @@ func TestError(t *testing.T) {
 	}
 }
 
+func TestErrorWithHandler(t *testing.T) {
+	w := tracedResponseWriter{
+		headers: make(map[string][]string),
+	}
+	Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) {
+		http.Error(w, err.Error(), 499)
+	})
+	assert.Equal(t, 499, w.code)
+	assert.True(t, w.hasBody)
+	assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
+}
+
 func TestOk(t *testing.T) {
 	w := tracedResponseWriter{
 		headers: make(map[string][]string),

+ 0 - 1
tools/goctl/api/apigen/gen.go

@@ -52,7 +52,6 @@ func ApiCommand(c *cli.Context) error {
 	}
 	defer fp.Close()
 
-
 	home := c.String("home")
 	remote := c.String("remote")
 	if len(remote) > 0 {