Browse Source

chore: refactor to simplify disabling builtin middlewares (#2031)

* chore: refactor to simplify disabling builtin middlewares

* chore: rename methods
Kevin Wan 2 years ago
parent
commit
018ca82048
4 changed files with 48 additions and 31 deletions
  1. 11 18
      rest/engine.go
  2. 3 1
      rest/engine_test.go
  3. 7 12
      rest/server.go
  4. 27 0
      rest/server_test.go

+ 11 - 18
rest/engine.go

@@ -25,15 +25,15 @@ const topCpuUsage = 1000
 var ErrSignatureConfig = errors.New("bad config for Signature")
 
 type engine struct {
-	conf                 RestConf
-	routes               []featuredRoutes
-	unauthorizedCallback handler.UnauthorizedCallback
-	unsignedCallback     handler.UnsignedCallback
-	middlewares          []Middleware
-	shedder              load.Shedder
-	priorityShedder      load.Shedder
-	tlsConfig            *tls.Config
-	chain                *alice.Chain
+	conf                      RestConf
+	routes                    []featuredRoutes
+	unauthorizedCallback      handler.UnauthorizedCallback
+	unsignedCallback          handler.UnsignedCallback
+	disableDefaultMiddlewares bool
+	middlewares               []Middleware
+	shedder                   load.Shedder
+	priorityShedder           load.Shedder
+	tlsConfig                 *tls.Config
 }
 
 func newEngine(c RestConf) *engine {
@@ -87,7 +87,7 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
 func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
 	route Route, verifier func(chain alice.Chain) alice.Chain) error {
 	var chain alice.Chain
-	if ng.chain == nil {
+	if !ng.disableDefaultMiddlewares {
 		chain = alice.New(
 			handler.TracingHandler(ng.conf.Name, route.Path),
 			ng.getLogHandler(),
@@ -101,15 +101,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
 			handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
 			handler.GunzipHandler,
 		)
-	} else {
-		chain = *ng.chain
 	}
 
-	chain = ng.appendAuthHandler(fr, chain, verifier)
-
 	for _, middleware := range ng.middlewares {
 		chain = chain.Append(convertMiddleware(middleware))
 	}
+	chain = ng.appendAuthHandler(fr, chain, verifier)
 	handle := chain.ThenFunc(route.Handler)
 
 	return router.Handle(route.Method, route.Path, handle)
@@ -213,10 +210,6 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
 	ng.tlsConfig = cfg
 }
 
-func (ng *engine) setChainConfig(chain *alice.Chain) {
-	ng.chain = chain
-}
-
 func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
 	ng.unauthorizedCallback = callback
 }

+ 3 - 1
rest/engine_test.go

@@ -250,7 +250,9 @@ func TestEngine_checkedChain(t *testing.T) {
 		}
 	}
 
-	server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2()))
+	server := MustNewServer(RestConf{}, DisableDefaultMiddlewares())
+	server.Use(ToMiddleware(middleware1()))
+	server.Use(ToMiddleware(middleware2()))
 	server.router = chainRouter{}
 	server.AddRoutes(
 		[]Route{

+ 7 - 12
rest/server.go

@@ -7,7 +7,6 @@ import (
 	"path"
 	"time"
 
-	"github.com/justinas/alice"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/httpx"
@@ -96,6 +95,13 @@ func (s *Server) Use(middleware Middleware) {
 	s.ngin.use(middleware)
 }
 
+// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares.
+func DisableDefaultMiddlewares() RunOption {
+	return func(svr *Server) {
+		svr.ngin.disableDefaultMiddlewares = true
+	}
+}
+
 // ToMiddleware converts the given handler to a Middleware.
 func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
 	return func(handle http.HandlerFunc) http.HandlerFunc {
@@ -243,17 +249,6 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
 	}
 }
 
-// WithChain returns a RunOption that with given chain config.
-func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption {
-	return func(svr *Server) {
-		chain := alice.New()
-		for _, middleware := range middlewares {
-			chain = chain.Append(middleware)
-		}
-		svr.ngin.setChainConfig(&chain)
-	}
-}
-
 // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
 func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
 	return func(svr *Server) {

+ 27 - 0
rest/server_test.go

@@ -15,6 +15,7 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/conf"
 	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/service"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/router"
 )
@@ -102,6 +103,18 @@ Port: 54321
 	}
 }
 
+func TestNewServerError(t *testing.T) {
+	_, err := NewServer(RestConf{
+		ServiceConf: service.ServiceConf{
+			Log: logx.LogConf{
+				// file mode, no path specified
+				Mode: "file",
+			},
+		},
+	})
+	assert.NotNil(t, err)
+}
+
 func TestWithMaxBytes(t *testing.T) {
 	const maxBytes = 1000
 	var fr featuredRoutes
@@ -320,6 +333,7 @@ Port: 54321
 	rt := router.NewRouter()
 	svr, err := NewServer(cnf, WithRouter(rt))
 	assert.Nil(t, err)
+	defer svr.Stop()
 
 	opt := WithCors("local")
 	opt(svr)
@@ -408,3 +422,16 @@ Port: 54321
 	out := <-ch
 	assert.Equal(t, expect, out)
 }
+
+func TestHandleError(t *testing.T) {
+	assert.NotPanics(t, func() {
+		handleError(nil)
+		handleError(http.ErrServerClosed)
+	})
+}
+
+func TestValidateSecret(t *testing.T) {
+	assert.Panics(t, func() {
+		validateSecret("short")
+	})
+}