ソースを参照

chore: only allow cors middleware to change headers (#1276)

Kevin Wan 3 年 前
コミット
3dda557410

+ 2 - 2
rest/internal/cors/handlers.go

@@ -45,12 +45,12 @@ func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.H
 }
 
 // Middleware returns a middleware that adds CORS headers to the response.
-func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
+func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
 	return func(next http.HandlerFunc) http.HandlerFunc {
 		return func(w http.ResponseWriter, r *http.Request) {
 			checkAndSetHeaders(w, r, origins)
 			if fn != nil {
-				fn(w)
+				fn(w.Header())
 			}
 
 			if r.Method == http.MethodOptions {

+ 2 - 2
rest/internal/cors/handlers_test.go

@@ -114,8 +114,8 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 				r := httptest.NewRequest(method, "http://localhost", nil)
 				r.Header.Set(originHeader, test.reqOrigin)
 				w := httptest.NewRecorder()
-				handler := Middleware(func(w http.ResponseWriter) {
-					w.Header().Set("foo", "bar")
+				handler := Middleware(func(header http.Header) {
+					header.Set("foo", "bar")
 				}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
 					w.WriteHeader(http.StatusOK)
 				})

+ 4 - 3
rest/server.go

@@ -106,10 +106,11 @@ func WithCors(origin ...string) RunOption {
 
 // WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
 // fn lets caller customizing the response.
-func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption {
+func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter),
+	origin ...string) RunOption {
 	return func(server *Server) {
-		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...))
-		server.Use(cors.Middleware(fn, origin...))
+		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
+		server.Use(cors.Middleware(middlewareFn, origin...))
 	}
 }
 

+ 4 - 2
rest/server_test.go

@@ -322,8 +322,10 @@ Port: 54321
 	srv, err := NewServer(cnf, WithRouter(rt))
 	assert.Nil(t, err)
 
-	opt := WithCustomCors(func(w http.ResponseWriter) {
-		w.Header().Set("foo", "bar")
+	opt := WithCustomCors(func(header http.Header) {
+		header.Set("foo", "bar")
+	}, func(w http.ResponseWriter) {
+		w.WriteHeader(http.StatusOK)
 	}, "local")
 	opt(srv)
 }