Browse Source

feat: support CORS, better implementation (#1217)

* feat: support CORS, better implementation

* chore: refine code
Kevin Wan 3 years ago
parent
commit
28409791fa
3 changed files with 78 additions and 22 deletions
  1. 46 11
      rest/internal/cors/handlers.go
  2. 31 10
      rest/internal/cors/handlers_test.go
  3. 1 1
      rest/server.go

+ 46 - 11
rest/internal/cors/handlers.go

@@ -9,19 +9,23 @@ const (
 	allowHeaders     = "Access-Control-Allow-Headers"
 	allowHeaders     = "Access-Control-Allow-Headers"
 	allowCredentials = "Access-Control-Allow-Credentials"
 	allowCredentials = "Access-Control-Allow-Credentials"
 	exposeHeaders    = "Access-Control-Expose-Headers"
 	exposeHeaders    = "Access-Control-Expose-Headers"
+	requestMethod    = "Access-Control-Request-Method"
+	requestHeaders   = "Access-Control-Request-Headers"
 	allowHeadersVal  = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range"
 	allowHeadersVal  = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range"
 	exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers"
 	exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers"
 	methods          = "GET, HEAD, POST, PATCH, PUT, DELETE"
 	methods          = "GET, HEAD, POST, PATCH, PUT, DELETE"
 	allowTrue        = "true"
 	allowTrue        = "true"
 	maxAgeHeader     = "Access-Control-Max-Age"
 	maxAgeHeader     = "Access-Control-Max-Age"
 	maxAgeHeaderVal  = "86400"
 	maxAgeHeaderVal  = "86400"
+	varyHeader       = "Vary"
+	originHeader     = "Origin"
 )
 )
 
 
-// Handler handles cross domain not allowed requests.
+// NotAllowedHandler handles cross domain not allowed requests.
 // At most one origin can be specified, other origins are ignored if given, default to be *.
 // At most one origin can be specified, other origins are ignored if given, default to be *.
-func Handler(origin ...string) http.Handler {
+func NotAllowedHandler(origins ...string) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		setHeader(w, getOrigin(origin))
+		checkAndSetHeaders(w, r, origins)
 
 
 		if r.Method != http.MethodOptions {
 		if r.Method != http.MethodOptions {
 			w.WriteHeader(http.StatusNotFound)
 			w.WriteHeader(http.StatusNotFound)
@@ -32,10 +36,10 @@ func Handler(origin ...string) http.Handler {
 }
 }
 
 
 // Middleware returns a middleware that adds CORS headers to the response.
 // Middleware returns a middleware that adds CORS headers to the response.
-func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc {
+func Middleware(origins ...string) func(http.HandlerFunc) http.HandlerFunc {
 	return func(next http.HandlerFunc) http.HandlerFunc {
 	return func(next http.HandlerFunc) http.HandlerFunc {
 		return func(w http.ResponseWriter, r *http.Request) {
 		return func(w http.ResponseWriter, r *http.Request) {
-			setHeader(w, getOrigin(origin))
+			checkAndSetHeaders(w, r, origins)
 
 
 			if r.Method == http.MethodOptions {
 			if r.Method == http.MethodOptions {
 				w.WriteHeader(http.StatusNoContent)
 				w.WriteHeader(http.StatusNoContent)
@@ -46,12 +50,32 @@ func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc {
 	}
 	}
 }
 }
 
 
-func getOrigin(origins []string) string {
-	if len(origins) > 0 {
-		return origins[0]
-	} else {
-		return allOrigins
+func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) {
+	setVaryHeaders(w, r)
+
+	if len(origins) == 0 {
+		setHeader(w, allOrigins)
+		return
+	}
+
+	origin := r.Header.Get(originHeader)
+	if isOriginAllowed(origins, origin) {
+		setHeader(w, origin)
+	}
+}
+
+func isOriginAllowed(allows []string, origin string) bool {
+	for _, o := range allows {
+		if o == allOrigins {
+			return true
+		}
+
+		if o == origin {
+			return true
+		}
 	}
 	}
+
+	return false
 }
 }
 
 
 func setHeader(w http.ResponseWriter, origin string) {
 func setHeader(w http.ResponseWriter, origin string) {
@@ -59,6 +83,17 @@ func setHeader(w http.ResponseWriter, origin string) {
 	w.Header().Set(allowMethods, methods)
 	w.Header().Set(allowMethods, methods)
 	w.Header().Set(allowHeaders, allowHeadersVal)
 	w.Header().Set(allowHeaders, allowHeadersVal)
 	w.Header().Set(exposeHeaders, exposeHeadersVal)
 	w.Header().Set(exposeHeaders, exposeHeadersVal)
-	w.Header().Set(allowCredentials, allowTrue)
+	if origin != allOrigins {
+		w.Header().Set(allowCredentials, allowTrue)
+	}
 	w.Header().Set(maxAgeHeader, maxAgeHeaderVal)
 	w.Header().Set(maxAgeHeader, maxAgeHeaderVal)
 }
 }
+
+func setVaryHeaders(w http.ResponseWriter, r *http.Request) {
+	header := w.Header()
+	header.Add(varyHeader, originHeader)
+	if r.Method == http.MethodOptions {
+		header.Add(varyHeader, requestMethod)
+		header.Add(varyHeader, requestHeaders)
+	}
+}

+ 31 - 10
rest/internal/cors/handlers_test.go

@@ -10,23 +10,42 @@ import (
 
 
 func TestCorsHandlerWithOrigins(t *testing.T) {
 func TestCorsHandlerWithOrigins(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
-		name    string
-		origins []string
-		expect  string
+		name      string
+		origins   []string
+		reqOrigin string
+		expect    string
 	}{
 	}{
 		{
 		{
 			name:   "allow all origins",
 			name:   "allow all origins",
 			expect: allOrigins,
 			expect: allOrigins,
 		},
 		},
 		{
 		{
-			name:    "allow one origin",
-			origins: []string{"local"},
-			expect:  "local",
+			name:      "allow one origin",
+			origins:   []string{"http://local"},
+			reqOrigin: "http://local",
+			expect:    "http://local",
 		},
 		},
 		{
 		{
-			name:    "allow many origins",
-			origins: []string{"local", "remote"},
-			expect:  "local",
+			name:      "allow many origins",
+			origins:   []string{"http://local", "http://remote"},
+			reqOrigin: "http://local",
+			expect:    "http://local",
+		},
+		{
+			name:      "allow all origins",
+			reqOrigin: "http://local",
+			expect:    "*",
+		},
+		{
+			name:      "allow many origins with all mark",
+			origins:   []string{"http://local", "http://remote", "*"},
+			reqOrigin: "http://another",
+			expect:    "http://another",
+		},
+		{
+			name:      "not allow origin",
+			origins:   []string{"http://local", "http://remote"},
+			reqOrigin: "http://another",
 		},
 		},
 	}
 	}
 
 
@@ -41,8 +60,9 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 			test := test
 			test := test
 			t.Run(test.name+"-handler", func(t *testing.T) {
 			t.Run(test.name+"-handler", func(t *testing.T) {
 				r := httptest.NewRequest(method, "http://localhost", nil)
 				r := httptest.NewRequest(method, "http://localhost", nil)
+				r.Header.Set(originHeader, test.reqOrigin)
 				w := httptest.NewRecorder()
 				w := httptest.NewRecorder()
-				handler := Handler(test.origins...)
+				handler := NotAllowedHandler(test.origins...)
 				handler.ServeHTTP(w, r)
 				handler.ServeHTTP(w, r)
 				if method == http.MethodOptions {
 				if method == http.MethodOptions {
 					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
 					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
@@ -59,6 +79,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 			test := test
 			test := test
 			t.Run(test.name+"-middleware", func(t *testing.T) {
 			t.Run(test.name+"-middleware", func(t *testing.T) {
 				r := httptest.NewRequest(method, "http://localhost", nil)
 				r := httptest.NewRequest(method, "http://localhost", nil)
+				r.Header.Set(originHeader, test.reqOrigin)
 				w := httptest.NewRecorder()
 				w := httptest.NewRecorder()
 				handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
 				handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
 					w.WriteHeader(http.StatusOK)
 					w.WriteHeader(http.StatusOK)

+ 1 - 1
rest/server.go

@@ -99,7 +99,7 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
 // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
 // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
 func WithCors(origin ...string) RunOption {
 func WithCors(origin ...string) RunOption {
 	return func(server *Server) {
 	return func(server *Server) {
-		server.router.SetNotAllowedHandler(cors.Handler(origin...))
+		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(origin...))
 		server.Use(cors.Middleware(origin...))
 		server.Use(cors.Middleware(origin...))
 	}
 	}
 }
 }