|
@@ -126,7 +126,7 @@ func WithChain(chn chain.Chain) RunOption {
|
|
func WithCors(origin ...string) RunOption {
|
|
func WithCors(origin ...string) RunOption {
|
|
return func(server *Server) {
|
|
return func(server *Server) {
|
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
|
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
|
|
- server.Use(cors.Middleware(nil, origin...))
|
|
|
|
|
|
+ server.router = newCorsRouter(server.router, nil, origin...)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -136,7 +136,7 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt
|
|
origin ...string) RunOption {
|
|
origin ...string) RunOption {
|
|
return func(server *Server) {
|
|
return func(server *Server) {
|
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
|
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
|
|
- server.Use(cors.Middleware(middlewareFn, origin...))
|
|
|
|
|
|
+ server.router = newCorsRouter(server.router, middlewareFn, origin...)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -291,3 +291,19 @@ func validateSecret(secret string) {
|
|
panic("secret's length can't be less than 8")
|
|
panic("secret's length can't be less than 8")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+type corsRouter struct {
|
|
|
|
+ httpx.Router
|
|
|
|
+ middleware Middleware
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...string) httpx.Router {
|
|
|
|
+ return &corsRouter{
|
|
|
|
+ Router: router,
|
|
|
|
+ middleware: cors.Middleware(headerFn, origins...),
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
|
+ c.middleware(c.Router.ServeHTTP)(w, r)
|
|
|
|
+}
|