|
@@ -8,10 +8,10 @@ import (
|
|
"sort"
|
|
"sort"
|
|
"time"
|
|
"time"
|
|
|
|
|
|
- "github.com/justinas/alice"
|
|
|
|
"github.com/zeromicro/go-zero/core/codec"
|
|
"github.com/zeromicro/go-zero/core/codec"
|
|
"github.com/zeromicro/go-zero/core/load"
|
|
"github.com/zeromicro/go-zero/core/load"
|
|
"github.com/zeromicro/go-zero/core/stat"
|
|
"github.com/zeromicro/go-zero/core/stat"
|
|
|
|
+ "github.com/zeromicro/go-zero/rest/chain"
|
|
"github.com/zeromicro/go-zero/rest/handler"
|
|
"github.com/zeromicro/go-zero/rest/handler"
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
"github.com/zeromicro/go-zero/rest/internal"
|
|
"github.com/zeromicro/go-zero/rest/internal"
|
|
@@ -25,15 +25,15 @@ const topCpuUsage = 1000
|
|
var ErrSignatureConfig = errors.New("bad config for Signature")
|
|
var ErrSignatureConfig = errors.New("bad config for Signature")
|
|
|
|
|
|
type engine struct {
|
|
type engine struct {
|
|
- conf RestConf
|
|
|
|
- routes []featuredRoutes
|
|
|
|
- unauthorizedCallback handler.UnauthorizedCallback
|
|
|
|
- unsignedCallback handler.UnsignedCallback
|
|
|
|
- disableDefaultMiddlewares bool
|
|
|
|
- middlewares []Middleware
|
|
|
|
- shedder load.Shedder
|
|
|
|
- priorityShedder load.Shedder
|
|
|
|
- tlsConfig *tls.Config
|
|
|
|
|
|
+ conf RestConf
|
|
|
|
+ routes []featuredRoutes
|
|
|
|
+ unauthorizedCallback handler.UnauthorizedCallback
|
|
|
|
+ unsignedCallback handler.UnsignedCallback
|
|
|
|
+ chain chain.Chain
|
|
|
|
+ middlewares []Middleware
|
|
|
|
+ shedder load.Shedder
|
|
|
|
+ priorityShedder load.Shedder
|
|
|
|
+ tlsConfig *tls.Config
|
|
}
|
|
}
|
|
|
|
|
|
func newEngine(c RestConf) *engine {
|
|
func newEngine(c RestConf) *engine {
|
|
@@ -53,20 +53,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
|
ng.routes = append(ng.routes, r)
|
|
ng.routes = append(ng.routes, r)
|
|
}
|
|
}
|
|
|
|
|
|
-func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
|
|
|
- verifier func(alice.Chain) alice.Chain) alice.Chain {
|
|
|
|
|
|
+func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
|
|
|
+ verifier func(chain.Chain) chain.Chain) chain.Chain {
|
|
if fr.jwt.enabled {
|
|
if fr.jwt.enabled {
|
|
if len(fr.jwt.prevSecret) == 0 {
|
|
if len(fr.jwt.prevSecret) == 0 {
|
|
- chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
|
|
|
|
|
+ chn = chn.Append(handler.Authorize(fr.jwt.secret,
|
|
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
|
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
|
} else {
|
|
} else {
|
|
- chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
|
|
|
|
|
+ chn = chn.Append(handler.Authorize(fr.jwt.secret,
|
|
handler.WithPrevSecret(fr.jwt.prevSecret),
|
|
handler.WithPrevSecret(fr.jwt.prevSecret),
|
|
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
|
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- return verifier(chain)
|
|
|
|
|
|
+ return verifier(chn)
|
|
}
|
|
}
|
|
|
|
|
|
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
|
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
|
@@ -85,10 +85,10 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
|
|
}
|
|
}
|
|
|
|
|
|
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
|
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
|
- route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
|
|
|
- var chain alice.Chain
|
|
|
|
- if !ng.disableDefaultMiddlewares {
|
|
|
|
- chain = alice.New(
|
|
|
|
|
|
+ route Route, verifier func(chain.Chain) chain.Chain) error {
|
|
|
|
+ chn := ng.chain
|
|
|
|
+ if chn == nil {
|
|
|
|
+ chn = chain.New(
|
|
handler.TracingHandler(ng.conf.Name, route.Path),
|
|
handler.TracingHandler(ng.conf.Name, route.Path),
|
|
ng.getLogHandler(),
|
|
ng.getLogHandler(),
|
|
handler.PrometheusHandler(route.Path),
|
|
handler.PrometheusHandler(route.Path),
|
|
@@ -103,11 +103,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
|
|
)
|
|
)
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ chn = ng.appendAuthHandler(fr, chn, verifier)
|
|
|
|
+
|
|
for _, middleware := range ng.middlewares {
|
|
for _, middleware := range ng.middlewares {
|
|
- chain = chain.Append(convertMiddleware(middleware))
|
|
|
|
|
|
+ chn = chn.Append(convertMiddleware(middleware))
|
|
}
|
|
}
|
|
- chain = ng.appendAuthHandler(fr, chain, verifier)
|
|
|
|
- handle := chain.ThenFunc(route.Handler)
|
|
|
|
|
|
+ handle := chn.ThenFunc(route.Handler)
|
|
|
|
|
|
return router.Handle(route.Method, route.Path, handle)
|
|
return router.Handle(route.Method, route.Path, handle)
|
|
}
|
|
}
|
|
@@ -171,16 +172,16 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
|
// notFoundHandler returns a middleware that handles 404 not found requests.
|
|
// notFoundHandler returns a middleware that handles 404 not found requests.
|
|
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
|
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
- chain := alice.New(
|
|
|
|
|
|
+ chn := chain.New(
|
|
handler.TracingHandler(ng.conf.Name, ""),
|
|
handler.TracingHandler(ng.conf.Name, ""),
|
|
ng.getLogHandler(),
|
|
ng.getLogHandler(),
|
|
)
|
|
)
|
|
|
|
|
|
var h http.Handler
|
|
var h http.Handler
|
|
if next != nil {
|
|
if next != nil {
|
|
- h = chain.Then(next)
|
|
|
|
|
|
+ h = chn.Then(next)
|
|
} else {
|
|
} else {
|
|
- h = chain.Then(http.NotFoundHandler())
|
|
|
|
|
|
+ h = chn.Then(http.NotFoundHandler())
|
|
}
|
|
}
|
|
|
|
|
|
cw := response.NewHeaderOnceResponseWriter(w)
|
|
cw := response.NewHeaderOnceResponseWriter(w)
|
|
@@ -218,10 +219,10 @@ func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
|
|
ng.unsignedCallback = callback
|
|
ng.unsignedCallback = callback
|
|
}
|
|
}
|
|
|
|
|
|
-func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
|
|
|
|
|
+func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
|
|
if !signature.enabled {
|
|
if !signature.enabled {
|
|
- return func(chain alice.Chain) alice.Chain {
|
|
|
|
- return chain
|
|
|
|
|
|
+ return func(chn chain.Chain) chain.Chain {
|
|
|
|
+ return chn
|
|
}, nil
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -230,8 +231,8 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
|
|
return nil, ErrSignatureConfig
|
|
return nil, ErrSignatureConfig
|
|
}
|
|
}
|
|
|
|
|
|
- return func(chain alice.Chain) alice.Chain {
|
|
|
|
- return chain
|
|
|
|
|
|
+ return func(chn chain.Chain) chain.Chain {
|
|
|
|
+ return chn
|
|
}, nil
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -247,14 +248,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
|
|
decrypters[fingerprint] = decrypter
|
|
decrypters[fingerprint] = decrypter
|
|
}
|
|
}
|
|
|
|
|
|
- return func(chain alice.Chain) alice.Chain {
|
|
|
|
|
|
+ return func(chn chain.Chain) chain.Chain {
|
|
if ng.unsignedCallback != nil {
|
|
if ng.unsignedCallback != nil {
|
|
- return chain.Append(handler.ContentSecurityHandler(
|
|
|
|
|
|
+ return chn.Append(handler.ContentSecurityHandler(
|
|
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
|
|
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
|
|
}
|
|
}
|
|
|
|
|
|
- return chain.Append(handler.ContentSecurityHandler(
|
|
|
|
- decrypters, signature.Expiry, signature.Strict))
|
|
|
|
|
|
+ return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
|
|
}, nil
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
|