浏览代码

add user middleware chain function (#1913)

* add user middleware chain function

* fix staticcheck SA4006

* chang code Implementation style

Co-authored-by: kemq1 <kemq1@spdb.com.cn>
magickeha 2 年之前
父节点
当前提交
6976ba7e13
共有 3 个文件被更改,包括 90 次插入13 次删除
  1. 24 13
      rest/engine.go
  2. 54 0
      rest/engine_test.go
  3. 12 0
      rest/server.go

+ 24 - 13
rest/engine.go

@@ -33,6 +33,7 @@ type engine struct {
 	shedder              load.Shedder
 	priorityShedder      load.Shedder
 	tlsConfig            *tls.Config
+	chain                *alice.Chain
 }
 
 func newEngine(c RestConf) *engine {
@@ -85,19 +86,25 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
 
 func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
 	route Route, verifier func(chain alice.Chain) alice.Chain) error {
-	chain := alice.New(
-		handler.TracingHandler(ng.conf.Name, route.Path),
-		ng.getLogHandler(),
-		handler.PrometheusHandler(route.Path),
-		handler.MaxConns(ng.conf.MaxConns),
-		handler.BreakerHandler(route.Method, route.Path, metrics),
-		handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
-		handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
-		handler.RecoverHandler,
-		handler.MetricHandler(metrics),
-		handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
-		handler.GunzipHandler,
-	)
+	var chain alice.Chain
+	if ng.chain == nil {
+		chain = alice.New(
+			handler.TracingHandler(ng.conf.Name, route.Path),
+			ng.getLogHandler(),
+			handler.PrometheusHandler(route.Path),
+			handler.MaxConns(ng.conf.MaxConns),
+			handler.BreakerHandler(route.Method, route.Path, metrics),
+			handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
+			handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
+			handler.RecoverHandler,
+			handler.MetricHandler(metrics),
+			handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
+			handler.GunzipHandler,
+		)
+	} else {
+		chain = *ng.chain
+	}
+
 	chain = ng.appendAuthHandler(fr, chain, verifier)
 
 	for _, middleware := range ng.middlewares {
@@ -206,6 +213,10 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
 	ng.tlsConfig = cfg
 }
 
+func (ng *engine) setChainConfig(chain *alice.Chain) {
+	ng.chain = chain
+}
+
 func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
 	ng.unauthorizedCallback = callback
 }

+ 54 - 0
rest/engine_test.go

@@ -229,6 +229,44 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
 	}
 }
 
+func TestEngine_checkedChain(t *testing.T) {
+	var called int32
+	middleware1 := func() func(http.Handler) http.Handler {
+		return func(next http.Handler) http.Handler {
+			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				atomic.AddInt32(&called, 1)
+				next.ServeHTTP(w, r)
+				atomic.AddInt32(&called, 1)
+			})
+		}
+	}
+	middleware2 := func() func(http.Handler) http.Handler {
+		return func(next http.Handler) http.Handler {
+			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				atomic.AddInt32(&called, 1)
+				next.ServeHTTP(w, r)
+				atomic.AddInt32(&called, 1)
+			})
+		}
+	}
+
+	server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2()))
+	server.router = chainRouter{}
+	server.AddRoutes(
+		[]Route{
+			{
+				Method: http.MethodGet,
+				Path:   "/",
+				Handler: func(_ http.ResponseWriter, _ *http.Request) {
+					atomic.AddInt32(&called, 1)
+				},
+			},
+		},
+	)
+	server.ngin.bindRoutes(chainRouter{})
+	assert.Equal(t, int32(5), atomic.LoadInt32(&called))
+}
+
 func TestEngine_notFoundHandler(t *testing.T) {
 	logx.Disable()
 
@@ -343,3 +381,19 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
 
 func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
 }
+
+type chainRouter struct{}
+
+func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
+}
+
+func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
+	handler.ServeHTTP(nil, nil)
+	return nil
+}
+
+func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
+}
+
+func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
+}

+ 12 - 0
rest/server.go

@@ -7,6 +7,7 @@ import (
 	"path"
 	"time"
 
+	"github.com/justinas/alice"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/httpx"
@@ -242,6 +243,17 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
 	}
 }
 
+// WithChain returns a RunOption that with given chain config.
+func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption {
+	return func(svr *Server) {
+		chain := alice.New()
+		for _, middleware := range middlewares {
+			chain = chain.Append(middleware)
+		}
+		svr.ngin.setChainConfig(&chain)
+	}
+}
+
 // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
 func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
 	return func(svr *Server) {