Kevin Wan 2 år sedan
förälder
incheckning
4cb68a034a
2 ändrade filer med 39 tillägg och 2 borttagningar
  1. 18 2
      rest/server.go
  2. 21 0
      rest/server_test.go

+ 18 - 2
rest/server.go

@@ -126,7 +126,7 @@ func WithChain(chn chain.Chain) RunOption {
 func WithCors(origin ...string) RunOption {
 	return func(server *Server) {
 		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
-		server.Use(cors.Middleware(nil, origin...))
+		server.router = newCorsRouter(server.router, nil, origin...)
 	}
 }
 
@@ -136,7 +136,7 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt
 	origin ...string) RunOption {
 	return func(server *Server) {
 		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
-		server.Use(cors.Middleware(middlewareFn, origin...))
+		server.router = newCorsRouter(server.router, middlewareFn, origin...)
 	}
 }
 
@@ -291,3 +291,19 @@ func validateSecret(secret string) {
 		panic("secret's length can't be less than 8")
 	}
 }
+
+type corsRouter struct {
+	httpx.Router
+	middleware Middleware
+}
+
+func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...string) httpx.Router {
+	return &corsRouter{
+		Router:     router,
+		middleware: cors.Middleware(headerFn, origins...),
+	}
+}
+
+func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	c.middleware(c.Router.ServeHTTP)(w, r)
+}

+ 21 - 0
rest/server_test.go

@@ -18,6 +18,7 @@ import (
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/rest/chain"
 	"github.com/zeromicro/go-zero/rest/httpx"
+	"github.com/zeromicro/go-zero/rest/internal/cors"
 	"github.com/zeromicro/go-zero/rest/router"
 )
 
@@ -515,3 +516,23 @@ func TestServer_WithChain(t *testing.T) {
 	rt.ServeHTTP(httptest.NewRecorder(), req)
 	assert.Equal(t, int32(5), atomic.LoadInt32(&called))
 }
+
+func TestServer_WithCors(t *testing.T) {
+	var called int32
+	middleware := func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			atomic.AddInt32(&called, 1)
+			next.ServeHTTP(w, r)
+		})
+	}
+	r := router.NewRouter()
+	assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler())))
+
+	cr := &corsRouter{
+		Router:     r,
+		middleware: cors.Middleware(nil, "*"),
+	}
+	req := httptest.NewRequest(http.MethodOptions, "/", nil)
+	cr.ServeHTTP(httptest.NewRecorder(), req)
+	assert.Equal(t, int32(0), atomic.LoadInt32(&called))
+}