Răsfoiți Sursa

support cors in rest server

kevin 4 ani în urmă
părinte
comite
fe0d0687f5

+ 1 - 1
example/http/demo/main.go

@@ -56,7 +56,7 @@ func main() {
 		Port:     *port,
 		Timeout:  *timeout,
 		MaxConns: 500,
-	})
+	}, rest.WithNotAllowedHandler(rest.CorsHandler()))
 	defer engine.Stop()
 
 	engine.Use(first)

+ 29 - 0
rest/handlers.go

@@ -0,0 +1,29 @@
+package rest
+
+import (
+	"net/http"
+	"strings"
+)
+
+const (
+	allowOrigin  = "Access-Control-Allow-Origin"
+	allOrigin    = "*"
+	allowMethods = "Access-Control-Allow-Methods"
+	allowHeaders = "Access-Control-Allow-Headers"
+	headers      = "Content-Type, Content-Length, Origin"
+	methods      = "GET, HEAD, POST, PATCH, PUT, DELETE"
+	separator    = ", "
+)
+
+func CorsHandler(origins ...string) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if len(origins) > 0 {
+			w.Header().Set(allowOrigin, strings.Join(origins, separator))
+		} else {
+			w.Header().Set(allowOrigin, allOrigin)
+		}
+		w.Header().Set(allowMethods, methods)
+		w.Header().Set(allowHeaders, headers)
+		w.WriteHeader(http.StatusNoContent)
+	})
+}

+ 27 - 0
rest/handlers_test.go

@@ -0,0 +1,27 @@
+package rest
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCorsHandler(t *testing.T) {
+	w := httptest.NewRecorder()
+	handler := CorsHandler()
+	handler.ServeHTTP(w, nil)
+	assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
+	assert.Equal(t, allOrigin, w.Header().Get(allowOrigin))
+}
+
+func TestCorsHandlerWithOrigins(t *testing.T) {
+	origins := []string{"local", "remote"}
+	w := httptest.NewRecorder()
+	handler := CorsHandler(origins...)
+	handler.ServeHTTP(w, nil)
+	assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
+	assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin))
+}

+ 1 - 0
rest/httpx/router.go

@@ -6,4 +6,5 @@ type Router interface {
 	http.Handler
 	Handle(method string, path string, handler http.Handler) error
 	SetNotFoundHandler(handler http.Handler)
+	SetNotAllowedHandler(handler http.Handler)
 }

+ 16 - 5
rest/router/patrouter.go

@@ -22,8 +22,9 @@ var (
 )
 
 type patRouter struct {
-	trees    map[string]*search.Tree
-	notFound http.Handler
+	trees      map[string]*search.Tree
+	notFound   http.Handler
+	notAllowed http.Handler
 }
 
 func NewRouter() httpx.Router {
@@ -63,11 +64,17 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
+	allow, ok := pr.methodNotAllowed(r.Method, reqPath)
+	if !ok {
+		pr.handleNotFound(w, r)
+		return
+	}
+
+	if pr.notAllowed != nil {
+		pr.notAllowed.ServeHTTP(w, r)
+	} else {
 		w.Header().Set(allowHeader, allow)
 		w.WriteHeader(http.StatusMethodNotAllowed)
-	} else {
-		pr.handleNotFound(w, r)
 	}
 }
 
@@ -75,6 +82,10 @@ func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
 	pr.notFound = handler
 }
 
+func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
+	pr.notAllowed = handler
+}
+
 func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
 	if pr.notFound != nil {
 		pr.notFound.ServeHTTP(w, r)

+ 18 - 1
rest/router/patrouter_test.go

@@ -60,13 +60,30 @@ func TestPatRouterNotFound(t *testing.T) {
 	router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		notFound = true
 	}))
-	router.Handle(http.MethodGet, "/a/b", nil)
+	err := router.Handle(http.MethodGet, "/a/b",
+		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+	assert.Nil(t, err)
 	r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
 	w := new(mockedResponseWriter)
 	router.ServeHTTP(w, r)
 	assert.True(t, notFound)
 }
 
+func TestPatRouterNotAllowed(t *testing.T) {
+	var notAllowed bool
+	router := NewRouter()
+	router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		notAllowed = true
+	}))
+	err := router.Handle(http.MethodGet, "/a/b",
+		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+	assert.Nil(t, err)
+	r, _ := http.NewRequest(http.MethodPost, "/a/b", nil)
+	w := new(mockedResponseWriter)
+	router.ServeHTTP(w, r)
+	assert.True(t, notAllowed)
+}
+
 func TestPatRouter(t *testing.T) {
 	tests := []struct {
 		method string

+ 18 - 0
rest/server.go

@@ -1,12 +1,14 @@
 package rest
 
 import (
+	"errors"
 	"log"
 	"net/http"
 
 	"github.com/tal-tech/go-zero/core/logx"
 	"github.com/tal-tech/go-zero/rest/handler"
 	"github.com/tal-tech/go-zero/rest/httpx"
+	"github.com/tal-tech/go-zero/rest/router"
 )
 
 type (
@@ -32,6 +34,10 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
 }
 
 func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
+	if len(opts) > 1 {
+		return nil, errors.New("only one RunOption is allowed")
+	}
+
 	if err := c.SetUp(); err != nil {
 		return nil, err
 	}
@@ -125,6 +131,18 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
 	return routes
 }
 
+func WithNotFoundHandler(handler http.Handler) RunOption {
+	rt := router.NewRouter()
+	rt.SetNotFoundHandler(handler)
+	return WithRouter(rt)
+}
+
+func WithNotAllowedHandler(handler http.Handler) RunOption {
+	rt := router.NewRouter()
+	rt.SetNotAllowedHandler(handler)
+	return WithRouter(rt)
+}
+
 func WithPriority() RouteOption {
 	return func(r *featuredRoutes) {
 		r.priority = true

+ 12 - 1
rest/server_test.go

@@ -12,6 +12,11 @@ import (
 	"github.com/tal-tech/go-zero/rest/router"
 )
 
+func TestNewServer(t *testing.T) {
+	_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
+	assert.NotNil(t, err)
+}
+
 func TestWithMiddleware(t *testing.T) {
 	m := make(map[string]string)
 	router := router.NewRouter()
@@ -69,7 +74,7 @@ func TestWithMiddleware(t *testing.T) {
 	}, m)
 }
 
-func TestMultiMiddleware(t *testing.T) {
+func TestMultiMiddlewares(t *testing.T) {
 	m := make(map[string]string)
 	router := router.NewRouter()
 	handler := func(w http.ResponseWriter, r *http.Request) {
@@ -140,3 +145,9 @@ func TestMultiMiddleware(t *testing.T) {
 		"whatever": "200000200000",
 	}, m)
 }
+
+func TestWithPriority(t *testing.T) {
+	var fr featuredRoutes
+	WithPriority()(&fr)
+	assert.True(t, fr.priority)
+}