Browse Source

feat: add rest.WithCustomCors to let caller customize the response (#1274)

Kevin Wan 3 years ago
parent
commit
0395ba1816
4 changed files with 72 additions and 6 deletions
  1. 8 2
      rest/internal/cors/handlers.go
  2. 36 2
      rest/internal/cors/handlers_test.go
  3. 11 2
      rest/server.go
  4. 17 0
      rest/server_test.go

+ 8 - 2
rest/internal/cors/handlers.go

@@ -23,9 +23,12 @@ const (
 
 // NotAllowedHandler handles cross domain not allowed requests.
 // At most one origin can be specified, other origins are ignored if given, default to be *.
-func NotAllowedHandler(origins ...string) http.Handler {
+func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		checkAndSetHeaders(w, r, origins)
+		if fn != nil {
+			fn(w)
+		}
 
 		if r.Method != http.MethodOptions {
 			w.WriteHeader(http.StatusNotFound)
@@ -36,10 +39,13 @@ func NotAllowedHandler(origins ...string) http.Handler {
 }
 
 // Middleware returns a middleware that adds CORS headers to the response.
-func Middleware(origins ...string) func(http.HandlerFunc) http.HandlerFunc {
+func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
 	return func(next http.HandlerFunc) http.HandlerFunc {
 		return func(w http.ResponseWriter, r *http.Request) {
 			checkAndSetHeaders(w, r, origins)
+			if fn != nil {
+				fn(w)
+			}
 
 			if r.Method == http.MethodOptions {
 				w.WriteHeader(http.StatusNoContent)

+ 36 - 2
rest/internal/cors/handlers_test.go

@@ -62,7 +62,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 				r := httptest.NewRequest(method, "http://localhost", nil)
 				r.Header.Set(originHeader, test.reqOrigin)
 				w := httptest.NewRecorder()
-				handler := NotAllowedHandler(test.origins...)
+				handler := NotAllowedHandler(nil, test.origins...)
 				handler.ServeHTTP(w, r)
 				if method == http.MethodOptions {
 					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
@@ -71,6 +71,22 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 				}
 				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
 			})
+			t.Run(test.name+"-handler-custom", func(t *testing.T) {
+				r := httptest.NewRequest(method, "http://localhost", nil)
+				r.Header.Set(originHeader, test.reqOrigin)
+				w := httptest.NewRecorder()
+				handler := NotAllowedHandler(func(w http.ResponseWriter) {
+					w.Header().Set("foo", "bar")
+				}, test.origins...)
+				handler.ServeHTTP(w, r)
+				if method == http.MethodOptions {
+					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
+				} else {
+					assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
+				}
+				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
+				assert.Equal(t, "bar", w.Header().Get("foo"))
+			})
 		}
 	}
 
@@ -81,7 +97,24 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 				r := httptest.NewRequest(method, "http://localhost", nil)
 				r.Header.Set(originHeader, test.reqOrigin)
 				w := httptest.NewRecorder()
-				handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
+				handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
+					w.WriteHeader(http.StatusOK)
+				})
+				handler.ServeHTTP(w, r)
+				if method == http.MethodOptions {
+					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
+				} else {
+					assert.Equal(t, http.StatusOK, w.Result().StatusCode)
+				}
+				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
+			})
+			t.Run(test.name+"-middleware-custom", func(t *testing.T) {
+				r := httptest.NewRequest(method, "http://localhost", nil)
+				r.Header.Set(originHeader, test.reqOrigin)
+				w := httptest.NewRecorder()
+				handler := Middleware(func(w http.ResponseWriter) {
+					w.Header().Set("foo", "bar")
+				}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
 					w.WriteHeader(http.StatusOK)
 				})
 				handler.ServeHTTP(w, r)
@@ -91,6 +124,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 					assert.Equal(t, http.StatusOK, w.Result().StatusCode)
 				}
 				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
+				assert.Equal(t, "bar", w.Header().Get("foo"))
 			})
 		}
 	}

+ 11 - 2
rest/server.go

@@ -99,8 +99,17 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
 // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
 func WithCors(origin ...string) RunOption {
 	return func(server *Server) {
-		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(origin...))
-		server.Use(cors.Middleware(origin...))
+		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
+		server.Use(cors.Middleware(nil, origin...))
+	}
+}
+
+// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
+// fn lets caller customizing the response.
+func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption {
+	return func(server *Server) {
+		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...))
+		server.Use(cors.Middleware(fn, origin...))
 	}
 }
 

+ 17 - 0
rest/server_test.go

@@ -310,3 +310,20 @@ Port: 54321
 	opt := WithCors("local")
 	opt(srv)
 }
+
+func TestWithCustomCors(t *testing.T) {
+	const configYaml = `
+Name: foo
+Port: 54321
+`
+	var cnf RestConf
+	assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
+	rt := router.NewRouter()
+	srv, err := NewServer(cnf, WithRouter(rt))
+	assert.Nil(t, err)
+
+	opt := WithCustomCors(func(w http.ResponseWriter) {
+		w.Header().Set("foo", "bar")
+	}, "local")
+	opt(srv)
+}