|
@@ -62,7 +62,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
|
|
r.Header.Set(originHeader, test.reqOrigin)
|
|
|
w := httptest.NewRecorder()
|
|
|
- handler := NotAllowedHandler(test.origins...)
|
|
|
+ handler := NotAllowedHandler(nil, test.origins...)
|
|
|
handler.ServeHTTP(w, r)
|
|
|
if method == http.MethodOptions {
|
|
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
|
@@ -71,6 +71,22 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|
|
}
|
|
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
|
|
})
|
|
|
+ t.Run(test.name+"-handler-custom", func(t *testing.T) {
|
|
|
+ r := httptest.NewRequest(method, "http://localhost", nil)
|
|
|
+ r.Header.Set(originHeader, test.reqOrigin)
|
|
|
+ w := httptest.NewRecorder()
|
|
|
+ handler := NotAllowedHandler(func(w http.ResponseWriter) {
|
|
|
+ w.Header().Set("foo", "bar")
|
|
|
+ }, test.origins...)
|
|
|
+ handler.ServeHTTP(w, r)
|
|
|
+ if method == http.MethodOptions {
|
|
|
+ assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
|
|
+ } else {
|
|
|
+ assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
|
|
|
+ }
|
|
|
+ assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
|
|
+ assert.Equal(t, "bar", w.Header().Get("foo"))
|
|
|
+ })
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -81,7 +97,24 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
|
|
r.Header.Set(originHeader, test.reqOrigin)
|
|
|
w := httptest.NewRecorder()
|
|
|
- handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ w.WriteHeader(http.StatusOK)
|
|
|
+ })
|
|
|
+ handler.ServeHTTP(w, r)
|
|
|
+ if method == http.MethodOptions {
|
|
|
+ assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
|
|
+ } else {
|
|
|
+ assert.Equal(t, http.StatusOK, w.Result().StatusCode)
|
|
|
+ }
|
|
|
+ assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
|
|
+ })
|
|
|
+ t.Run(test.name+"-middleware-custom", func(t *testing.T) {
|
|
|
+ r := httptest.NewRequest(method, "http://localhost", nil)
|
|
|
+ r.Header.Set(originHeader, test.reqOrigin)
|
|
|
+ w := httptest.NewRecorder()
|
|
|
+ handler := Middleware(func(w http.ResponseWriter) {
|
|
|
+ w.Header().Set("foo", "bar")
|
|
|
+ }, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
})
|
|
|
handler.ServeHTTP(w, r)
|
|
@@ -91,6 +124,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|
|
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
|
|
|
}
|
|
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
|
|
+ assert.Equal(t, "bar", w.Header().Get("foo"))
|
|
|
})
|
|
|
}
|
|
|
}
|