Browse Source

feat: support CORS by using rest.WithCors(...) (#1212)

* feat: support CORS by using rest.WithCors(...)

* chore: add comments

* refactor: lowercase unexported methods

* ci: fix lint errors
Kevin Wan 3 years ago
parent
commit
c28e01fed3

+ 26 - 31
rest/engine.go

@@ -14,7 +14,6 @@ import (
 	"github.com/tal-tech/go-zero/rest/handler"
 	"github.com/tal-tech/go-zero/rest/httpx"
 	"github.com/tal-tech/go-zero/rest/internal"
-	"github.com/tal-tech/go-zero/rest/router"
 )
 
 // use 1000m to represent 100%
@@ -47,39 +46,10 @@ func newEngine(c RestConf) *engine {
 	return srv
 }
 
-func (ng *engine) AddRoutes(r featuredRoutes) {
+func (ng *engine) addRoutes(r featuredRoutes) {
 	ng.routes = append(ng.routes, r)
 }
 
-func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
-	ng.unauthorizedCallback = callback
-}
-
-func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
-	ng.unsignedCallback = callback
-}
-
-func (ng *engine) Start() error {
-	return ng.StartWithRouter(router.NewRouter())
-}
-
-func (ng *engine) StartWithRouter(router httpx.Router) error {
-	if err := ng.bindRoutes(router); err != nil {
-		return err
-	}
-
-	if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
-		return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
-	}
-
-	return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
-		ng.conf.KeyFile, router, func(srv *http.Server) {
-			if ng.tlsConfig != nil {
-				srv.TLSConfig = ng.tlsConfig
-			}
-		})
-}
-
 func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
 	verifier func(alice.Chain) alice.Chain) alice.Chain {
 	if fr.jwt.enabled {
@@ -188,6 +158,14 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
 	ng.tlsConfig = cfg
 }
 
+func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
+	ng.unauthorizedCallback = callback
+}
+
+func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
+	ng.unsignedCallback = callback
+}
+
 func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
 	if !signature.enabled {
 		return func(chain alice.Chain) alice.Chain {
@@ -228,6 +206,23 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
 	}, nil
 }
 
+func (ng *engine) start(router httpx.Router) error {
+	if err := ng.bindRoutes(router); err != nil {
+		return err
+	}
+
+	if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
+		return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
+	}
+
+	return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
+		ng.conf.KeyFile, router, func(srv *http.Server) {
+			if ng.tlsConfig != nil {
+				srv.TLSConfig = ng.tlsConfig
+			}
+		})
+}
+
 func (ng *engine) use(middleware Middleware) {
 	ng.middlewares = append(ng.middlewares, middleware)
 }

+ 2 - 2
rest/engine_test.go

@@ -144,13 +144,13 @@ Verbose: true
 			var cnf RestConf
 			assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
 			ng := newEngine(cnf)
-			ng.AddRoutes(route)
+			ng.addRoutes(route)
 			ng.use(func(next http.HandlerFunc) http.HandlerFunc {
 				return func(w http.ResponseWriter, r *http.Request) {
 					next.ServeHTTP(w, r)
 				}
 			})
-			assert.NotNil(t, ng.StartWithRouter(mockedRouter{}))
+			assert.NotNil(t, ng.start(mockedRouter{}))
 		}
 	}
 }

+ 0 - 27
rest/handlers.go

@@ -1,27 +0,0 @@
-package rest
-
-import "net/http"
-
-const (
-	allowOrigin  = "Access-Control-Allow-Origin"
-	allOrigins   = "*"
-	allowMethods = "Access-Control-Allow-Methods"
-	allowHeaders = "Access-Control-Allow-Headers"
-	headers      = "Content-Type, Content-Length, Origin"
-	methods      = "GET, HEAD, POST, PATCH, PUT, DELETE"
-)
-
-// CorsHandler handles cross domain OPTIONS requests.
-// At most one origin can be specified, other origins are ignored if given.
-func CorsHandler(origins ...string) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		if len(origins) > 0 {
-			w.Header().Set(allowOrigin, origins[0])
-		} else {
-			w.Header().Set(allowOrigin, allOrigins)
-		}
-		w.Header().Set(allowMethods, methods)
-		w.Header().Set(allowHeaders, headers)
-		w.WriteHeader(http.StatusNoContent)
-	})
-}

+ 0 - 42
rest/handlers_test.go

@@ -1,42 +0,0 @@
-package rest
-
-import (
-	"net/http"
-	"net/http/httptest"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCorsHandlerWithOrigins(t *testing.T) {
-	tests := []struct {
-		name    string
-		origins []string
-		expect  string
-	}{
-		{
-			name:   "allow all origins",
-			expect: allOrigins,
-		},
-		{
-			name:    "allow one origin",
-			origins: []string{"local"},
-			expect:  "local",
-		},
-		{
-			name:    "allow many origins",
-			origins: []string{"local", "remote"},
-			expect:  "local",
-		},
-	}
-
-	for _, test := range tests {
-		t.Run(test.name, func(t *testing.T) {
-			w := httptest.NewRecorder()
-			handler := CorsHandler(test.origins...)
-			handler.ServeHTTP(w, nil)
-			assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
-			assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
-		})
-	}
-}

+ 64 - 0
rest/internal/cors/handlers.go

@@ -0,0 +1,64 @@
+package cors
+
+import "net/http"
+
+const (
+	allowOrigin      = "Access-Control-Allow-Origin"
+	allOrigins       = "*"
+	allowMethods     = "Access-Control-Allow-Methods"
+	allowHeaders     = "Access-Control-Allow-Headers"
+	allowCredentials = "Access-Control-Allow-Credentials"
+	exposeHeaders    = "Access-Control-Expose-Headers"
+	allowHeadersVal  = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range"
+	exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers"
+	methods          = "GET, HEAD, POST, PATCH, PUT, DELETE"
+	allowTrue        = "true"
+	maxAgeHeader     = "Access-Control-Max-Age"
+	maxAgeHeaderVal  = "86400"
+)
+
+// Handler handles cross domain not allowed requests.
+// At most one origin can be specified, other origins are ignored if given, default to be *.
+func Handler(origin ...string) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		setHeader(w, getOrigin(origin))
+
+		if r.Method != http.MethodOptions {
+			w.WriteHeader(http.StatusNotFound)
+		} else {
+			w.WriteHeader(http.StatusNoContent)
+		}
+	})
+}
+
+// Middleware returns a middleware that adds CORS headers to the response.
+func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc {
+	return func(next http.HandlerFunc) http.HandlerFunc {
+		return func(w http.ResponseWriter, r *http.Request) {
+			setHeader(w, getOrigin(origin))
+
+			if r.Method == http.MethodOptions {
+				w.WriteHeader(http.StatusNoContent)
+			} else {
+				next(w, r)
+			}
+		}
+	}
+}
+
+func getOrigin(origins []string) string {
+	if len(origins) > 0 {
+		return origins[0]
+	} else {
+		return allOrigins
+	}
+}
+
+func setHeader(w http.ResponseWriter, origin string) {
+	w.Header().Set(allowOrigin, origin)
+	w.Header().Set(allowMethods, methods)
+	w.Header().Set(allowHeaders, allowHeadersVal)
+	w.Header().Set(exposeHeaders, exposeHeadersVal)
+	w.Header().Set(allowCredentials, allowTrue)
+	w.Header().Set(maxAgeHeader, maxAgeHeaderVal)
+}

+ 76 - 0
rest/internal/cors/handlers_test.go

@@ -0,0 +1,76 @@
+package cors
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCorsHandlerWithOrigins(t *testing.T) {
+	tests := []struct {
+		name    string
+		origins []string
+		expect  string
+	}{
+		{
+			name:   "allow all origins",
+			expect: allOrigins,
+		},
+		{
+			name:    "allow one origin",
+			origins: []string{"local"},
+			expect:  "local",
+		},
+		{
+			name:    "allow many origins",
+			origins: []string{"local", "remote"},
+			expect:  "local",
+		},
+	}
+
+	methods := []string{
+		http.MethodOptions,
+		http.MethodGet,
+		http.MethodPost,
+	}
+
+	for _, test := range tests {
+		for _, method := range methods {
+			test := test
+			t.Run(test.name+"-handler", func(t *testing.T) {
+				r := httptest.NewRequest(method, "http://localhost", nil)
+				w := httptest.NewRecorder()
+				handler := Handler(test.origins...)
+				handler.ServeHTTP(w, r)
+				if method == http.MethodOptions {
+					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
+				} else {
+					assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
+				}
+				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
+			})
+		}
+	}
+
+	for _, test := range tests {
+		for _, method := range methods {
+			test := test
+			t.Run(test.name+"-middleware", func(t *testing.T) {
+				r := httptest.NewRequest(method, "http://localhost", nil)
+				w := httptest.NewRecorder()
+				handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
+					w.WriteHeader(http.StatusOK)
+				})
+				handler.ServeHTTP(w, r)
+				if method == http.MethodOptions {
+					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
+				} else {
+					assert.Equal(t, http.StatusOK, w.Result().StatusCode)
+				}
+				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
+			})
+		}
+	}
+}

+ 24 - 25
rest/server.go

@@ -10,21 +10,18 @@ import (
 	"github.com/tal-tech/go-zero/core/logx"
 	"github.com/tal-tech/go-zero/rest/handler"
 	"github.com/tal-tech/go-zero/rest/httpx"
+	"github.com/tal-tech/go-zero/rest/internal/cors"
 	"github.com/tal-tech/go-zero/rest/router"
 )
 
 type (
-	runOptions struct {
-		start func(*engine) error
-	}
-
 	// RunOption defines the method to customize a Server.
 	RunOption func(*Server)
 
 	// A Server is a http server.
 	Server struct {
-		ngin *engine
-		opts runOptions
+		ngin   *engine
+		router httpx.Router
 	}
 )
 
@@ -48,12 +45,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
 	}
 
 	server := &Server{
-		ngin: newEngine(c),
-		opts: runOptions{
-			start: func(ng *engine) error {
-				return ng.Start()
-			},
-		},
+		ngin:   newEngine(c),
+		router: router.NewRouter(),
 	}
 
 	for _, opt := range opts {
@@ -71,7 +64,7 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
 	for _, opt := range opts {
 		opt(&r)
 	}
-	s.ngin.AddRoutes(r)
+	s.ngin.addRoutes(r)
 }
 
 // AddRoute adds given route into the Server.
@@ -83,7 +76,7 @@ func (s *Server) AddRoute(r Route, opts ...RouteOption) {
 // Graceful shutdown is enabled by default.
 // Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
 func (s *Server) Start() {
-	handleError(s.opts.start(s.ngin))
+	handleError(s.ngin.start(s.router))
 }
 
 // Stop stops the Server.
@@ -103,6 +96,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
 	}
 }
 
+// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
+func WithCors(origin ...string) RunOption {
+	return func(server *Server) {
+		server.router.SetNotAllowedHandler(cors.Handler(origin...))
+		server.Use(cors.Middleware(origin...))
+	}
+}
+
 // WithJwt returns a func to enable jwt authentication in given route.
 func WithJwt(secret string) RouteOption {
 	return func(r *featuredRoutes) {
@@ -151,16 +152,16 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
 
 // WithNotFoundHandler returns a RunOption with not found handler set to given handler.
 func WithNotFoundHandler(handler http.Handler) RunOption {
-	rt := router.NewRouter()
-	rt.SetNotFoundHandler(handler)
-	return WithRouter(rt)
+	return func(server *Server) {
+		server.router.SetNotFoundHandler(handler)
+	}
 }
 
 // WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler.
 func WithNotAllowedHandler(handler http.Handler) RunOption {
-	rt := router.NewRouter()
-	rt.SetNotAllowedHandler(handler)
-	return WithRouter(rt)
+	return func(server *Server) {
+		server.router.SetNotAllowedHandler(handler)
+	}
 }
 
 // WithPrefix adds group as a prefix to the route paths.
@@ -189,9 +190,7 @@ func WithPriority() RouteOption {
 // WithRouter returns a RunOption that make server run with given router.
 func WithRouter(router httpx.Router) RunOption {
 	return func(server *Server) {
-		server.opts.start = func(ng *engine) error {
-			return ng.StartWithRouter(router)
-		}
+		server.router = router
 	}
 }
 
@@ -222,14 +221,14 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
 // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
 func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
 	return func(srv *Server) {
-		srv.ngin.SetUnauthorizedCallback(callback)
+		srv.ngin.setUnauthorizedCallback(callback)
 	}
 }
 
 // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
 func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
 	return func(srv *Server) {
-		srv.ngin.SetUnsignedCallback(callback)
+		srv.ngin.setUnsignedCallback(callback)
 	}
 }
 

+ 44 - 16
rest/server_test.go

@@ -22,11 +22,6 @@ Port: 54321
 `
 	var cnf RestConf
 	assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
-	failStart := func(server *Server) {
-		server.opts.start = func(e *engine) error {
-			return http.ErrServerClosed
-		}
-	}
 
 	tests := []struct {
 		c    RestConf
@@ -35,38 +30,40 @@ Port: 54321
 	}{
 		{
 			c:    RestConf{},
-			opts: []RunOption{failStart},
+			opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
 			fail: true,
 		},
 		{
 			c:    cnf,
-			opts: []RunOption{failStart},
+			opts: []RunOption{WithRouter(mockedRouter{})},
 		},
 		{
 			c:    cnf,
-			opts: []RunOption{WithNotAllowedHandler(nil), failStart},
+			opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
 		},
 		{
 			c:    cnf,
-			opts: []RunOption{WithNotFoundHandler(nil), failStart},
+			opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
 		},
 		{
 			c:    cnf,
-			opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
+			opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
 		},
 		{
 			c:    cnf,
-			opts: []RunOption{WithUnsignedCallback(nil), failStart},
+			opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
 		},
 	}
 
 	for _, test := range tests {
-		srv, err := NewServer(test.c, test.opts...)
+		var srv *Server
+		var err error
 		if test.fail {
+			_, err = NewServer(test.c, test.opts...)
 			assert.NotNil(t, err)
-		}
-		if err != nil {
 			continue
+		} else {
+			srv = MustNewServer(test.c, test.opts...)
 		}
 
 		srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
@@ -80,8 +77,21 @@ Port: 54321
 			Handler: nil,
 		}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
 			WithJwtTransition("preivous", "thenewone"))
-		srv.Start()
-		srv.Stop()
+
+		func() {
+			defer func() {
+				p := recover()
+				switch v := p.(type) {
+				case error:
+					assert.Equal(t, "foo", v.Error())
+				default:
+					t.Fail()
+				}
+			}()
+
+			srv.Start()
+			srv.Stop()
+		}()
 	}
 }
 
@@ -180,6 +190,9 @@ func TestMultiMiddlewares(t *testing.T) {
 				next.ServeHTTP(w, r)
 			}
 		},
+		ToMiddleware(func(next http.Handler) http.Handler {
+			return next
+		}),
 	}, Route{
 		Method:  http.MethodGet,
 		Path:    "/first/:name/:year",
@@ -282,3 +295,18 @@ Port: 54321
 		assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
 	}
 }
+
+func TestWithCors(t *testing.T) {
+	const configYaml = `
+Name: foo
+Port: 54321
+`
+	var cnf RestConf
+	assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
+	rt := router.NewRouter()
+	srv, err := NewServer(cnf, WithRouter(rt))
+	assert.Nil(t, err)
+
+	opt := WithCors("local")
+	opt(srv)
+}

+ 2 - 2
tools/goctl/api/gogen/genroutes.go

@@ -27,12 +27,12 @@ import (
 	{{.importPackages}}
 )
 
-func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) {
+func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
 	{{.routesAdditions}}
 }
 `
 	routesAdditionTemplate = `
-	engine.AddRoutes(
+	server.AddRoutes(
 		{{.routes}} {{.jwt}}{{.signature}} {{.prefix}}
 	)
 `