|
@@ -1,13 +1,17 @@
|
|
|
package rest
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"errors"
|
|
|
"net/http"
|
|
|
+ "net/http/httptest"
|
|
|
+ "sync/atomic"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/zeromicro/go-zero/core/conf"
|
|
|
+ "github.com/zeromicro/go-zero/core/logx"
|
|
|
)
|
|
|
|
|
|
func TestNewEngine(t *testing.T) {
|
|
@@ -190,6 +194,75 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestEngine_notFoundHandler(t *testing.T) {
|
|
|
+ logx.Disable()
|
|
|
+
|
|
|
+ ng := newEngine(RestConf{})
|
|
|
+ ts := httptest.NewServer(ng.notFoundHandler(nil))
|
|
|
+ defer ts.Close()
|
|
|
+
|
|
|
+ client := ts.Client()
|
|
|
+ err := func(ctx context.Context) error {
|
|
|
+ req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
|
|
|
+ assert.Nil(t, err)
|
|
|
+ res, err := client.Do(req)
|
|
|
+ assert.Nil(t, err)
|
|
|
+ assert.Equal(t, http.StatusNotFound, res.StatusCode)
|
|
|
+ return res.Body.Close()
|
|
|
+ }(context.Background())
|
|
|
+
|
|
|
+ assert.Nil(t, err)
|
|
|
+}
|
|
|
+
|
|
|
+func TestEngine_notFoundHandlerNotNil(t *testing.T) {
|
|
|
+ logx.Disable()
|
|
|
+
|
|
|
+ ng := newEngine(RestConf{})
|
|
|
+ var called int32
|
|
|
+ ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ atomic.AddInt32(&called, 1)
|
|
|
+ })))
|
|
|
+ defer ts.Close()
|
|
|
+
|
|
|
+ client := ts.Client()
|
|
|
+ err := func(ctx context.Context) error {
|
|
|
+ req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
|
|
|
+ assert.Nil(t, err)
|
|
|
+ res, err := client.Do(req)
|
|
|
+ assert.Nil(t, err)
|
|
|
+ assert.Equal(t, http.StatusNotFound, res.StatusCode)
|
|
|
+ return res.Body.Close()
|
|
|
+ }(context.Background())
|
|
|
+
|
|
|
+ assert.Nil(t, err)
|
|
|
+ assert.Equal(t, int32(1), atomic.LoadInt32(&called))
|
|
|
+}
|
|
|
+
|
|
|
+func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
|
|
|
+ logx.Disable()
|
|
|
+
|
|
|
+ ng := newEngine(RestConf{})
|
|
|
+ var called int32
|
|
|
+ ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ atomic.AddInt32(&called, 1)
|
|
|
+ w.WriteHeader(http.StatusExpectationFailed)
|
|
|
+ })))
|
|
|
+ defer ts.Close()
|
|
|
+
|
|
|
+ client := ts.Client()
|
|
|
+ err := func(ctx context.Context) error {
|
|
|
+ req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
|
|
|
+ assert.Nil(t, err)
|
|
|
+ res, err := client.Do(req)
|
|
|
+ assert.Nil(t, err)
|
|
|
+ assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
|
|
|
+ return res.Body.Close()
|
|
|
+ }(context.Background())
|
|
|
+
|
|
|
+ assert.Nil(t, err)
|
|
|
+ assert.Equal(t, int32(1), atomic.LoadInt32(&called))
|
|
|
+}
|
|
|
+
|
|
|
type mockedRouter struct{}
|
|
|
|
|
|
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|