Răsfoiți Sursa

feat: rest.WithChain to replace builtin middlewares (#2033)

* feat: rest.WithChain to replace builtin middlewares

* chore: add comments

* chore: refine code
Kevin Wan 2 ani în urmă
părinte
comite
47c49de94e
6 a modificat fișierele cu 322 adăugiri și 98 ștergeri
  1. 109 0
      rest/chain/chain.go
  2. 126 0
      rest/chain/chain_test.go
  3. 34 34
      rest/engine.go
  4. 1 57
      rest/engine_test.go
  5. 9 7
      rest/server.go
  6. 43 0
      rest/server_test.go

+ 109 - 0
rest/chain/chain.go

@@ -0,0 +1,109 @@
+package chain
+
+// This is a modified version of https://github.com/justinas/alice
+// The original code is licensed under the MIT license.
+// It's modified for couple reasons:
+// - Added the Chain interface
+// - Added support for the Chain.Prepend(...) method
+
+import "net/http"
+
+type (
+	// Chain defines a chain of middleware.
+	Chain interface {
+		Append(middlewares ...Middleware) Chain
+		Prepend(middlewares ...Middleware) Chain
+		Then(h http.Handler) http.Handler
+		ThenFunc(fn http.HandlerFunc) http.Handler
+	}
+
+	// Middleware is an HTTP middleware.
+	Middleware func(http.Handler) http.Handler
+
+	// chain acts as a list of http.Handler middlewares.
+	// chain is effectively immutable:
+	// once created, it will always hold
+	// the same set of middlewares in the same order.
+	chain struct {
+		middlewares []Middleware
+	}
+)
+
+// New creates a new Chain, memorizing the given list of middleware middlewares.
+// New serves no other function, middlewares are only called upon a call to Then() or ThenFunc().
+func New(middlewares ...Middleware) Chain {
+	return chain{middlewares: append(([]Middleware)(nil), middlewares...)}
+}
+
+// Append extends a chain, adding the specified middlewares as the last ones in the request flow.
+//
+//     c := chain.New(m1, m2)
+//     c.Append(m3, m4)
+//     // requests in c go m1 -> m2 -> m3 -> m4
+func (c chain) Append(middlewares ...Middleware) Chain {
+	return chain{middlewares: join(c.middlewares, middlewares)}
+}
+
+// Prepend extends a chain by adding the specified chain as the first one in the request flow.
+//
+//     c := chain.New(m3, m4)
+//     c1 := chain.New(m1, m2)
+//     c.Prepend(c1)
+//     // requests in c go m1 -> m2 -> m3 -> m4
+func (c chain) Prepend(middlewares ...Middleware) Chain {
+	return chain{middlewares: join(middlewares, c.middlewares)}
+}
+
+// Then chains the middleware and returns the final http.Handler.
+//     New(m1, m2, m3).Then(h)
+// is equivalent to:
+//     m1(m2(m3(h)))
+// When the request comes in, it will be passed to m1, then m2, then m3
+// and finally, the given handler
+// (assuming every middleware calls the following one).
+//
+// A chain can be safely reused by calling Then() several times.
+//     stdStack := chain.New(ratelimitHandler, csrfHandler)
+//     indexPipe = stdStack.Then(indexHandler)
+//     authPipe = stdStack.Then(authHandler)
+// Note that middlewares are called on every call to Then() or ThenFunc()
+// and thus several instances of the same middleware will be created
+// when a chain is reused in this way.
+// For proper middleware, this should cause no problems.
+//
+// Then() treats nil as http.DefaultServeMux.
+func (c chain) Then(h http.Handler) http.Handler {
+	if h == nil {
+		h = http.DefaultServeMux
+	}
+
+	for i := range c.middlewares {
+		h = c.middlewares[len(c.middlewares)-1-i](h)
+	}
+
+	return h
+}
+
+// ThenFunc works identically to Then, but takes
+// a HandlerFunc instead of a Handler.
+//
+// The following two statements are equivalent:
+//     c.Then(http.HandlerFunc(fn))
+//     c.ThenFunc(fn)
+//
+// ThenFunc provides all the guarantees of Then.
+func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {
+	// This nil check cannot be removed due to the "nil is not nil" common mistake in Go.
+	// Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil
+	if fn == nil {
+		return c.Then(nil)
+	}
+	return c.Then(fn)
+}
+
+func join(a, b []Middleware) []Middleware {
+	mids := make([]Middleware, 0, len(a)+len(b))
+	mids = append(mids, a...)
+	mids = append(mids, b...)
+	return mids
+}

+ 126 - 0
rest/chain/chain_test.go

@@ -0,0 +1,126 @@
+package chain
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"reflect"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// A constructor for middleware
+// that writes its own "tag" into the RW and does nothing else.
+// Useful in checking if a chain is behaving in the right order.
+func tagMiddleware(tag string) Middleware {
+	return func(h http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Write([]byte(tag))
+			h.ServeHTTP(w, r)
+		})
+	}
+}
+
+// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
+// but the best we can do.
+func funcsEqual(f1, f2 interface{}) bool {
+	val1 := reflect.ValueOf(f1)
+	val2 := reflect.ValueOf(f2)
+	return val1.Pointer() == val2.Pointer()
+}
+
+var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	w.Write([]byte("app\n"))
+})
+
+func TestNew(t *testing.T) {
+	c1 := func(h http.Handler) http.Handler {
+		return nil
+	}
+
+	c2 := func(h http.Handler) http.Handler {
+		return http.StripPrefix("potato", nil)
+	}
+
+	slice := []Middleware{c1, c2}
+	c := New(slice...)
+	for k := range slice {
+		assert.True(t, funcsEqual(c.(chain).middlewares[k], slice[k]),
+			"New does not add constructors correctly")
+	}
+}
+
+func TestThenWorksWithNoMiddleware(t *testing.T) {
+	assert.True(t, funcsEqual(New().Then(testApp), testApp),
+		"Then does not work with no middleware")
+}
+
+func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
+	assert.Equal(t, http.DefaultServeMux, New().Then(nil),
+		"Then does not treat nil as DefaultServeMux")
+}
+
+func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
+	assert.Equal(t, http.DefaultServeMux, New().ThenFunc(nil),
+		"ThenFunc does not treat nil as DefaultServeMux")
+}
+
+func TestThenFuncConstructsHandlerFunc(t *testing.T) {
+	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(200)
+	})
+	chained := New().ThenFunc(fn)
+	rec := httptest.NewRecorder()
+
+	chained.ServeHTTP(rec, (*http.Request)(nil))
+
+	assert.Equal(t, reflect.TypeOf((http.HandlerFunc)(nil)), reflect.TypeOf(chained),
+		"ThenFunc does not construct HandlerFunc")
+}
+
+func TestThenOrdersHandlersCorrectly(t *testing.T) {
+	t1 := tagMiddleware("t1\n")
+	t2 := tagMiddleware("t2\n")
+	t3 := tagMiddleware("t3\n")
+
+	chained := New(t1, t2, t3).Then(testApp)
+
+	w := httptest.NewRecorder()
+	r, err := http.NewRequest("GET", "/", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	chained.ServeHTTP(w, r)
+
+	assert.Equal(t, "t1\nt2\nt3\napp\n", w.Body.String(),
+		"Then does not order handlers correctly")
+}
+
+func TestAppendAddsHandlersCorrectly(t *testing.T) {
+	c := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
+	c = c.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
+	h := c.Then(testApp)
+
+	w := httptest.NewRecorder()
+	r, err := http.NewRequest("GET", "/", nil)
+	assert.Nil(t, err)
+
+	h.ServeHTTP(w, r)
+	assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
+		"Append does not add handlers correctly")
+}
+
+func TestExtendAddsHandlersCorrectly(t *testing.T) {
+	c := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
+	c = c.Prepend(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
+	h := c.Then(testApp)
+
+	w := httptest.NewRecorder()
+	r, err := http.NewRequest("GET", "/", nil)
+	assert.Nil(t, err)
+
+	h.ServeHTTP(w, r)
+	assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
+		"Extend does not add handlers in correctly")
+}

+ 34 - 34
rest/engine.go

@@ -8,10 +8,10 @@ import (
 	"sort"
 	"sort"
 	"time"
 	"time"
 
 
-	"github.com/justinas/alice"
 	"github.com/zeromicro/go-zero/core/codec"
 	"github.com/zeromicro/go-zero/core/codec"
 	"github.com/zeromicro/go-zero/core/load"
 	"github.com/zeromicro/go-zero/core/load"
 	"github.com/zeromicro/go-zero/core/stat"
 	"github.com/zeromicro/go-zero/core/stat"
+	"github.com/zeromicro/go-zero/rest/chain"
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/internal"
 	"github.com/zeromicro/go-zero/rest/internal"
@@ -25,15 +25,15 @@ const topCpuUsage = 1000
 var ErrSignatureConfig = errors.New("bad config for Signature")
 var ErrSignatureConfig = errors.New("bad config for Signature")
 
 
 type engine struct {
 type engine struct {
-	conf                      RestConf
-	routes                    []featuredRoutes
-	unauthorizedCallback      handler.UnauthorizedCallback
-	unsignedCallback          handler.UnsignedCallback
-	disableDefaultMiddlewares bool
-	middlewares               []Middleware
-	shedder                   load.Shedder
-	priorityShedder           load.Shedder
-	tlsConfig                 *tls.Config
+	conf                 RestConf
+	routes               []featuredRoutes
+	unauthorizedCallback handler.UnauthorizedCallback
+	unsignedCallback     handler.UnsignedCallback
+	chain                chain.Chain
+	middlewares          []Middleware
+	shedder              load.Shedder
+	priorityShedder      load.Shedder
+	tlsConfig            *tls.Config
 }
 }
 
 
 func newEngine(c RestConf) *engine {
 func newEngine(c RestConf) *engine {
@@ -53,20 +53,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
 	ng.routes = append(ng.routes, r)
 	ng.routes = append(ng.routes, r)
 }
 }
 
 
-func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
-	verifier func(alice.Chain) alice.Chain) alice.Chain {
+func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
+	verifier func(chain.Chain) chain.Chain) chain.Chain {
 	if fr.jwt.enabled {
 	if fr.jwt.enabled {
 		if len(fr.jwt.prevSecret) == 0 {
 		if len(fr.jwt.prevSecret) == 0 {
-			chain = chain.Append(handler.Authorize(fr.jwt.secret,
+			chn = chn.Append(handler.Authorize(fr.jwt.secret,
 				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
 				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
 		} else {
 		} else {
-			chain = chain.Append(handler.Authorize(fr.jwt.secret,
+			chn = chn.Append(handler.Authorize(fr.jwt.secret,
 				handler.WithPrevSecret(fr.jwt.prevSecret),
 				handler.WithPrevSecret(fr.jwt.prevSecret),
 				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
 				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
 		}
 		}
 	}
 	}
 
 
-	return verifier(chain)
+	return verifier(chn)
 }
 }
 
 
 func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
 func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
@@ -85,10 +85,10 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
 }
 }
 
 
 func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
 func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
-	route Route, verifier func(chain alice.Chain) alice.Chain) error {
-	var chain alice.Chain
-	if !ng.disableDefaultMiddlewares {
-		chain = alice.New(
+	route Route, verifier func(chain.Chain) chain.Chain) error {
+	chn := ng.chain
+	if chn == nil {
+		chn = chain.New(
 			handler.TracingHandler(ng.conf.Name, route.Path),
 			handler.TracingHandler(ng.conf.Name, route.Path),
 			ng.getLogHandler(),
 			ng.getLogHandler(),
 			handler.PrometheusHandler(route.Path),
 			handler.PrometheusHandler(route.Path),
@@ -103,11 +103,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
 		)
 		)
 	}
 	}
 
 
+	chn = ng.appendAuthHandler(fr, chn, verifier)
+
 	for _, middleware := range ng.middlewares {
 	for _, middleware := range ng.middlewares {
-		chain = chain.Append(convertMiddleware(middleware))
+		chn = chn.Append(convertMiddleware(middleware))
 	}
 	}
-	chain = ng.appendAuthHandler(fr, chain, verifier)
-	handle := chain.ThenFunc(route.Handler)
+	handle := chn.ThenFunc(route.Handler)
 
 
 	return router.Handle(route.Method, route.Path, handle)
 	return router.Handle(route.Method, route.Path, handle)
 }
 }
@@ -171,16 +172,16 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
 // notFoundHandler returns a middleware that handles 404 not found requests.
 // notFoundHandler returns a middleware that handles 404 not found requests.
 func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
 func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		chain := alice.New(
+		chn := chain.New(
 			handler.TracingHandler(ng.conf.Name, ""),
 			handler.TracingHandler(ng.conf.Name, ""),
 			ng.getLogHandler(),
 			ng.getLogHandler(),
 		)
 		)
 
 
 		var h http.Handler
 		var h http.Handler
 		if next != nil {
 		if next != nil {
-			h = chain.Then(next)
+			h = chn.Then(next)
 		} else {
 		} else {
-			h = chain.Then(http.NotFoundHandler())
+			h = chn.Then(http.NotFoundHandler())
 		}
 		}
 
 
 		cw := response.NewHeaderOnceResponseWriter(w)
 		cw := response.NewHeaderOnceResponseWriter(w)
@@ -218,10 +219,10 @@ func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
 	ng.unsignedCallback = callback
 	ng.unsignedCallback = callback
 }
 }
 
 
-func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
+func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
 	if !signature.enabled {
 	if !signature.enabled {
-		return func(chain alice.Chain) alice.Chain {
-			return chain
+		return func(chn chain.Chain) chain.Chain {
+			return chn
 		}, nil
 		}, nil
 	}
 	}
 
 
@@ -230,8 +231,8 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
 			return nil, ErrSignatureConfig
 			return nil, ErrSignatureConfig
 		}
 		}
 
 
-		return func(chain alice.Chain) alice.Chain {
-			return chain
+		return func(chn chain.Chain) chain.Chain {
+			return chn
 		}, nil
 		}, nil
 	}
 	}
 
 
@@ -247,14 +248,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
 		decrypters[fingerprint] = decrypter
 		decrypters[fingerprint] = decrypter
 	}
 	}
 
 
-	return func(chain alice.Chain) alice.Chain {
+	return func(chn chain.Chain) chain.Chain {
 		if ng.unsignedCallback != nil {
 		if ng.unsignedCallback != nil {
-			return chain.Append(handler.ContentSecurityHandler(
+			return chn.Append(handler.ContentSecurityHandler(
 				decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
 				decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
 		}
 		}
 
 
-		return chain.Append(handler.ContentSecurityHandler(
-			decrypters, signature.Expiry, signature.Strict))
+		return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
 	}, nil
 	}, nil
 }
 }
 
 

+ 1 - 57
rest/engine_test.go

@@ -229,46 +229,6 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestEngine_checkedChain(t *testing.T) {
-	var called int32
-	middleware1 := func() func(http.Handler) http.Handler {
-		return func(next http.Handler) http.Handler {
-			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				atomic.AddInt32(&called, 1)
-				next.ServeHTTP(w, r)
-				atomic.AddInt32(&called, 1)
-			})
-		}
-	}
-	middleware2 := func() func(http.Handler) http.Handler {
-		return func(next http.Handler) http.Handler {
-			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				atomic.AddInt32(&called, 1)
-				next.ServeHTTP(w, r)
-				atomic.AddInt32(&called, 1)
-			})
-		}
-	}
-
-	server := MustNewServer(RestConf{}, DisableDefaultMiddlewares())
-	server.Use(ToMiddleware(middleware1()))
-	server.Use(ToMiddleware(middleware2()))
-	server.router = chainRouter{}
-	server.AddRoutes(
-		[]Route{
-			{
-				Method: http.MethodGet,
-				Path:   "/",
-				Handler: func(_ http.ResponseWriter, _ *http.Request) {
-					atomic.AddInt32(&called, 1)
-				},
-			},
-		},
-	)
-	server.ngin.bindRoutes(chainRouter{})
-	assert.Equal(t, int32(5), atomic.LoadInt32(&called))
-}
-
 func TestEngine_notFoundHandler(t *testing.T) {
 func TestEngine_notFoundHandler(t *testing.T) {
 	logx.Disable()
 	logx.Disable()
 
 
@@ -374,7 +334,7 @@ type mockedRouter struct{}
 func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
 func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
 }
 }
 
 
-func (m mockedRouter) Handle(_, _ string, _ http.Handler) error {
+func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
 	return errors.New("foo")
 	return errors.New("foo")
 }
 }
 
 
@@ -383,19 +343,3 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
 
 
 func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
 func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
 }
 }
-
-type chainRouter struct{}
-
-func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
-}
-
-func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
-	handler.ServeHTTP(nil, nil)
-	return nil
-}
-
-func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
-}
-
-func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
-}

+ 9 - 7
rest/server.go

@@ -8,6 +8,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/rest/chain"
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/internal/cors"
 	"github.com/zeromicro/go-zero/rest/internal/cors"
@@ -95,13 +96,6 @@ func (s *Server) Use(middleware Middleware) {
 	s.ngin.use(middleware)
 	s.ngin.use(middleware)
 }
 }
 
 
-// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares.
-func DisableDefaultMiddlewares() RunOption {
-	return func(svr *Server) {
-		svr.ngin.disableDefaultMiddlewares = true
-	}
-}
-
 // ToMiddleware converts the given handler to a Middleware.
 // ToMiddleware converts the given handler to a Middleware.
 func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
 func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
 	return func(handle http.HandlerFunc) http.HandlerFunc {
 	return func(handle http.HandlerFunc) http.HandlerFunc {
@@ -109,6 +103,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
 	}
 	}
 }
 }
 
 
+// WithChain returns a RunOption that uses the given chain to replace the default chain.
+// JWT auth middleware and the middlewares that added by svr.Use() will be appended.
+func WithChain(chn chain.Chain) RunOption {
+	return func(svr *Server) {
+		svr.ngin.chain = chn
+	}
+}
+
 // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
 // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
 func WithCors(origin ...string) RunOption {
 func WithCors(origin ...string) RunOption {
 	return func(server *Server) {
 	return func(server *Server) {

+ 43 - 0
rest/server_test.go

@@ -9,6 +9,7 @@ import (
 	"net/http/httptest"
 	"net/http/httptest"
 	"os"
 	"os"
 	"strings"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -16,6 +17,7 @@ import (
 	"github.com/zeromicro/go-zero/core/conf"
 	"github.com/zeromicro/go-zero/core/conf"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/service"
 	"github.com/zeromicro/go-zero/core/service"
+	"github.com/zeromicro/go-zero/rest/chain"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/router"
 	"github.com/zeromicro/go-zero/rest/router"
 )
 )
@@ -435,3 +437,44 @@ func TestValidateSecret(t *testing.T) {
 		validateSecret("short")
 		validateSecret("short")
 	})
 	})
 }
 }
+
+func TestServer_WithChain(t *testing.T) {
+	var called int32
+	middleware1 := func() func(http.Handler) http.Handler {
+		return func(next http.Handler) http.Handler {
+			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				atomic.AddInt32(&called, 1)
+				next.ServeHTTP(w, r)
+				atomic.AddInt32(&called, 1)
+			})
+		}
+	}
+	middleware2 := func() func(http.Handler) http.Handler {
+		return func(next http.Handler) http.Handler {
+			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				atomic.AddInt32(&called, 1)
+				next.ServeHTTP(w, r)
+				atomic.AddInt32(&called, 1)
+			})
+		}
+	}
+
+	server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
+	server.AddRoutes(
+		[]Route{
+			{
+				Method: http.MethodGet,
+				Path:   "/",
+				Handler: func(_ http.ResponseWriter, _ *http.Request) {
+					atomic.AddInt32(&called, 1)
+				},
+			},
+		},
+	)
+	rt := router.NewRouter()
+	assert.Nil(t, server.ngin.bindRoutes(rt))
+	req, err := http.NewRequest(http.MethodGet, "/", nil)
+	assert.Nil(t, err)
+	rt.ServeHTTP(httptest.NewRecorder(), req)
+	assert.Equal(t, int32(5), atomic.LoadInt32(&called))
+}