Browse Source

add error handle tests

kevin 4 years ago
parent
commit
9592639cb4
2 changed files with 72 additions and 13 deletions
  1. 14 7
      rest/httpx/responses.go
  2. 58 6
      rest/httpx/responses_test.go

+ 14 - 7
rest/httpx/responses.go

@@ -9,16 +9,27 @@ import (
 )
 )
 
 
 var (
 var (
-	errorHandler = defaultErrorHandler
+	errorHandler func(error) (int, interface{})
 	lock         sync.RWMutex
 	lock         sync.RWMutex
 )
 )
 
 
 func Error(w http.ResponseWriter, err error) {
 func Error(w http.ResponseWriter, err error) {
 	lock.RLock()
 	lock.RLock()
-	code, body := errorHandler(err)
+	handler := errorHandler
 	lock.RUnlock()
 	lock.RUnlock()
 
 
-	WriteJson(w, code, body)
+	if handler == nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	code, body := errorHandler(err)
+	e, ok := body.(error)
+	if ok {
+		http.Error(w, e.Error(), code)
+	} else {
+		WriteJson(w, code, body)
+	}
 }
 }
 
 
 func Ok(w http.ResponseWriter) {
 func Ok(w http.ResponseWriter) {
@@ -51,7 +62,3 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) {
 		logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
 		logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
 	}
 	}
 }
 }
-
-func defaultErrorHandler(err error) (int, interface{}) {
-	return http.StatusBadRequest, err
-}

+ 58 - 6
rest/httpx/responses_test.go

@@ -19,13 +19,65 @@ func init() {
 }
 }
 
 
 func TestError(t *testing.T) {
 func TestError(t *testing.T) {
-	const body = "foo"
-	w := tracedResponseWriter{
-		headers: make(map[string][]string),
+	const (
+		body        = "foo"
+		wrappedBody = `"foo"`
+	)
+
+	tests := []struct {
+		name         string
+		input        string
+		errorHandler func(error) (int, interface{})
+		expectBody   string
+		expectCode   int
+	}{
+		{
+			name:       "default error handler",
+			input:      body,
+			expectBody: body,
+			expectCode: http.StatusBadRequest,
+		},
+		{
+			name:  "customized error handler return string",
+			input: body,
+			errorHandler: func(err error) (int, interface{}) {
+				return http.StatusForbidden, err.Error()
+			},
+			expectBody: wrappedBody,
+			expectCode: http.StatusForbidden,
+		},
+		{
+			name:  "customized error handler return error",
+			input: body,
+			errorHandler: func(err error) (int, interface{}) {
+				return http.StatusForbidden, err
+			},
+			expectBody: body,
+			expectCode: http.StatusForbidden,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			w := tracedResponseWriter{
+				headers: make(map[string][]string),
+			}
+			if test.errorHandler != nil {
+				lock.RLock()
+				prev := errorHandler
+				lock.RUnlock()
+				SetErrorHandler(test.errorHandler)
+				defer func() {
+					lock.Lock()
+					errorHandler = prev
+					lock.Unlock()
+				}()
+			}
+			Error(&w, errors.New(test.input))
+			assert.Equal(t, test.expectCode, w.code)
+			assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
+		})
 	}
 	}
-	Error(&w, errors.New(body))
-	assert.Equal(t, http.StatusBadRequest, w.code)
-	assert.Equal(t, body, strings.TrimSpace(w.builder.String()))
 }
 }
 
 
 func TestOk(t *testing.T) {
 func TestOk(t *testing.T) {