ソースを参照

feat: add middlewares config for rest (#2765)

* feat: add middlewares config for rest

* chore: disable logs in tests

* chore: enable verbose in tests
Kevin Wan 2 年 前
コミット
ade6f9ee46

+ 17 - 0
rest/config.go

@@ -7,6 +7,21 @@ import (
 )
 
 type (
+	// MiddlewaresConf is the config of middlewares.
+	MiddlewaresConf struct {
+		Trace      bool `json:",default=true"`
+		Log        bool `json:",default=true"`
+		Prometheus bool `json:",default=true"`
+		MaxConns   bool `json:",default=true"`
+		Breaker    bool `json:",default=true"`
+		Shedding   bool `json:",default=true"`
+		Timeout    bool `json:",default=true"`
+		Recover    bool `json:",default=true"`
+		Metrics    bool `json:",default=true"`
+		MaxBytes   bool `json:",default=true"`
+		Gunzip     bool `json:",default=true"`
+	}
+
 	// A PrivateKeyConf is a private key config.
 	PrivateKeyConf struct {
 		Fingerprint string
@@ -40,5 +55,7 @@ type (
 		Timeout      int64         `json:",default=3000"`
 		CpuThreshold int64         `json:",default=900,range=[0:1000]"`
 		Signature    SignatureConf `json:",optional"`
+		// There are default values for all the items in Middlewares.
+		Middlewares MiddlewaresConf
 	}
 )

+ 42 - 13
rest/engine.go

@@ -88,19 +88,7 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
 	route Route, verifier func(chain.Chain) chain.Chain) error {
 	chn := ng.chain
 	if chn == nil {
-		chn = chain.New(
-			handler.TracingHandler(ng.conf.Name, route.Path),
-			ng.getLogHandler(),
-			handler.PrometheusHandler(route.Path),
-			handler.MaxConns(ng.conf.MaxConns),
-			handler.BreakerHandler(route.Method, route.Path, metrics),
-			handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
-			handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
-			handler.RecoverHandler,
-			handler.MetricHandler(metrics),
-			handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
-			handler.GunzipHandler,
-		)
+		chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
 	}
 
 	chn = ng.appendAuthHandler(fr, chn, verifier)
@@ -125,6 +113,47 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
 	return nil
 }
 
+func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
+	metrics *stat.Metrics) chain.Chain {
+	chn := chain.New()
+
+	if ng.conf.Middlewares.Trace {
+		chn = chn.Append(handler.TracingHandler(ng.conf.Name, route.Path))
+	}
+	if ng.conf.Middlewares.Log {
+		chn = chn.Append(ng.getLogHandler())
+	}
+	if ng.conf.Middlewares.Prometheus {
+		chn = chn.Append(handler.PrometheusHandler(route.Path))
+	}
+	if ng.conf.Middlewares.MaxConns {
+		chn = chn.Append(handler.MaxConnsHandler(ng.conf.MaxConns))
+	}
+	if ng.conf.Middlewares.Breaker {
+		chn = chn.Append(handler.BreakerHandler(route.Method, route.Path, metrics))
+	}
+	if ng.conf.Middlewares.Shedding {
+		chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics))
+	}
+	if ng.conf.Middlewares.Timeout {
+		chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)))
+	}
+	if ng.conf.Middlewares.Recover {
+		chn = chn.Append(handler.RecoverHandler)
+	}
+	if ng.conf.Middlewares.Metrics {
+		chn = chn.Append(handler.MetricHandler(metrics))
+	}
+	if ng.conf.Middlewares.MaxBytes {
+		chn = chn.Append(handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)))
+	}
+	if ng.conf.Middlewares.Gunzip {
+		chn = chn.Append(handler.GunzipHandler)
+	}
+
+	return chn
+}
+
 func (ng *engine) checkedMaxBytes(bytes int64) int64 {
 	if bytes > 0 {
 		return bytes

+ 4 - 0
rest/engine_test.go

@@ -18,10 +18,14 @@ func TestNewEngine(t *testing.T) {
 	yamls := []string{
 		`Name: foo
 Port: 54321
+Middlewares:
+  Log: false
 `,
 		`Name: foo
 Port: 54321
 CpuThreshold: 500
+Middlewares:
+  Log: false
 `,
 		`Name: foo
 Port: 54321

+ 2 - 2
rest/handler/maxconnshandler.go

@@ -8,8 +8,8 @@ import (
 	"github.com/zeromicro/go-zero/rest/internal"
 )
 
-// MaxConns returns a middleware that limit the concurrent connections.
-func MaxConns(n int) func(http.Handler) http.Handler {
+// MaxConnsHandler returns a middleware that limit the concurrent connections.
+func MaxConnsHandler(n int) func(http.Handler) http.Handler {
 	if n <= 0 {
 		return func(next http.Handler) http.Handler {
 			return next

+ 2 - 2
rest/handler/maxconnshandler_test.go

@@ -24,7 +24,7 @@ func TestMaxConnsHandler(t *testing.T) {
 	done := make(chan lang.PlaceholderType)
 	defer close(done)
 
-	maxConns := MaxConns(conns)
+	maxConns := MaxConnsHandler(conns)
 	handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		waitGroup.Done()
 		<-done
@@ -54,7 +54,7 @@ func TestWithoutMaxConnsHandler(t *testing.T) {
 	done := make(chan lang.PlaceholderType)
 	defer close(done)
 
-	maxConns := MaxConns(0)
+	maxConns := MaxConnsHandler(0)
 	handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		val := r.Header.Get(key)
 		if val == value {