浏览代码

refactor: simplify tls config in rest (#1181)

Kevin Wan 3 年之前
父节点
当前提交
769d06c8ab
共有 6 个文件被更改,包括 96 次插入78 次删除
  1. 1 0
      go.mod
  2. 1 1
      rest/config.go
  3. 56 47
      rest/engine.go
  4. 19 11
      rest/internal/starter.go
  5. 14 16
      rest/server.go
  6. 5 3
      rest/server_test.go

+ 1 - 0
go.mod

@@ -54,4 +54,5 @@ require (
 	k8s.io/api v0.20.10
 	k8s.io/apimachinery v0.20.10
 	k8s.io/client-go v0.20.10
+	k8s.io/utils v0.0.0-20201110183641-67b214c5f920
 )

+ 1 - 1
rest/config.go

@@ -35,7 +35,7 @@ type (
 		KeyFile  string `json:",optional"`
 		Verbose  bool   `json:",optional"`
 		MaxConns int    `json:",default=10000"`
-		MaxBytes int64  `json:",default=1048576,range=[0:33554432]"`
+		MaxBytes int64  `json:",default=1048576"`
 		// milliseconds
 		Timeout      int64         `json:",default=3000"`
 		CpuThreshold int64         `json:",default=900,range=[0:1000]"`

+ 56 - 47
rest/engine.go

@@ -47,58 +47,63 @@ func newEngine(c RestConf) *engine {
 	return srv
 }
 
-func (s *engine) AddRoutes(r featuredRoutes) {
-	s.routes = append(s.routes, r)
+func (ng *engine) AddRoutes(r featuredRoutes) {
+	ng.routes = append(ng.routes, r)
 }
 
-func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
-	s.unauthorizedCallback = callback
+func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
+	ng.unauthorizedCallback = callback
 }
 
-func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
-	s.unsignedCallback = callback
+func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
+	ng.unsignedCallback = callback
 }
 
-func (s *engine) Start() error {
-	return s.StartWithRouter(router.NewRouter())
+func (ng *engine) Start() error {
+	return ng.StartWithRouter(router.NewRouter())
 }
 
-func (s *engine) StartWithRouter(router httpx.Router) error {
-	if err := s.bindRoutes(router); err != nil {
+func (ng *engine) StartWithRouter(router httpx.Router) error {
+	if err := ng.bindRoutes(router); err != nil {
 		return err
 	}
 
-	if len(s.conf.CertFile) == 0 && len(s.conf.KeyFile) == 0 {
-		return internal.StartHttp(s.conf.Host, s.conf.Port, router)
+	if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
+		return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
 	}
 
-	return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, s.tlsConfig, router)
+	return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
+		ng.conf.KeyFile, router, func(srv *http.Server) {
+			if ng.tlsConfig != nil {
+				srv.TLSConfig = ng.tlsConfig
+			}
+		})
 }
 
-func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
+func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
 	verifier func(alice.Chain) alice.Chain) alice.Chain {
 	if fr.jwt.enabled {
 		if len(fr.jwt.prevSecret) == 0 {
 			chain = chain.Append(handler.Authorize(fr.jwt.secret,
-				handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
+				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
 		} else {
 			chain = chain.Append(handler.Authorize(fr.jwt.secret,
 				handler.WithPrevSecret(fr.jwt.prevSecret),
-				handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
+				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
 		}
 	}
 
 	return verifier(chain)
 }
 
-func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
-	verifier, err := s.signatureVerifier(fr.signature)
+func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
+	verifier, err := ng.signatureVerifier(fr.signature)
 	if err != nil {
 		return err
 	}
 
 	for _, route := range fr.routes {
-		if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
+		if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
 			return err
 		}
 	}
@@ -106,24 +111,24 @@ func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metr
 	return nil
 }
 
-func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
+func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
 	route Route, verifier func(chain alice.Chain) alice.Chain) error {
 	chain := alice.New(
-		handler.TracingHandler(s.conf.Name, route.Path),
-		s.getLogHandler(),
+		handler.TracingHandler(ng.conf.Name, route.Path),
+		ng.getLogHandler(),
 		handler.PrometheusHandler(route.Path),
-		handler.MaxConns(s.conf.MaxConns),
+		handler.MaxConns(ng.conf.MaxConns),
 		handler.BreakerHandler(route.Method, route.Path, metrics),
-		handler.SheddingHandler(s.getShedder(fr.priority), metrics),
-		handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
+		handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
+		handler.TimeoutHandler(time.Duration(ng.conf.Timeout)*time.Millisecond),
 		handler.RecoverHandler,
 		handler.MetricHandler(metrics),
-		handler.MaxBytesHandler(s.conf.MaxBytes),
+		handler.MaxBytesHandler(ng.conf.MaxBytes),
 		handler.GunzipHandler,
 	)
-	chain = s.appendAuthHandler(fr, chain, verifier)
+	chain = ng.appendAuthHandler(fr, chain, verifier)
 
-	for _, middleware := range s.middlewares {
+	for _, middleware := range ng.middlewares {
 		chain = chain.Append(convertMiddleware(middleware))
 	}
 	handle := chain.ThenFunc(route.Handler)
@@ -131,11 +136,11 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat
 	return router.Handle(route.Method, route.Path, handle)
 }
 
-func (s *engine) bindRoutes(router httpx.Router) error {
-	metrics := s.createMetrics()
+func (ng *engine) bindRoutes(router httpx.Router) error {
+	metrics := ng.createMetrics()
 
-	for _, fr := range s.routes {
-		if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
+	for _, fr := range ng.routes {
+		if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
 			return err
 		}
 	}
@@ -143,35 +148,39 @@ func (s *engine) bindRoutes(router httpx.Router) error {
 	return nil
 }
 
-func (s *engine) createMetrics() *stat.Metrics {
+func (ng *engine) createMetrics() *stat.Metrics {
 	var metrics *stat.Metrics
 
-	if len(s.conf.Name) > 0 {
-		metrics = stat.NewMetrics(s.conf.Name)
+	if len(ng.conf.Name) > 0 {
+		metrics = stat.NewMetrics(ng.conf.Name)
 	} else {
-		metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
+		metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
 	}
 
 	return metrics
 }
 
-func (s *engine) getLogHandler() func(http.Handler) http.Handler {
-	if s.conf.Verbose {
+func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
+	if ng.conf.Verbose {
 		return handler.DetailedLogHandler
 	}
 
 	return handler.LogHandler
 }
 
-func (s *engine) getShedder(priority bool) load.Shedder {
-	if priority && s.priorityShedder != nil {
-		return s.priorityShedder
+func (ng *engine) getShedder(priority bool) load.Shedder {
+	if priority && ng.priorityShedder != nil {
+		return ng.priorityShedder
 	}
 
-	return s.shedder
+	return ng.shedder
 }
 
-func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
+func (ng *engine) setTlsConfig(cfg *tls.Config) {
+	ng.tlsConfig = cfg
+}
+
+func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
 	if !signature.enabled {
 		return func(chain alice.Chain) alice.Chain {
 			return chain
@@ -201,9 +210,9 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
 	}
 
 	return func(chain alice.Chain) alice.Chain {
-		if s.unsignedCallback != nil {
+		if ng.unsignedCallback != nil {
 			return chain.Append(handler.ContentSecurityHandler(
-				decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
+				decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
 		}
 
 		return chain.Append(handler.ContentSecurityHandler(
@@ -211,8 +220,8 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
 	}, nil
 }
 
-func (s *engine) use(middleware Middleware) {
-	s.middlewares = append(s.middlewares, middleware)
+func (ng *engine) use(middleware Middleware) {
+	ng.middlewares = append(ng.middlewares, middleware)
 }
 
 func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {

+ 19 - 11
rest/internal/starter.go

@@ -2,38 +2,46 @@ package internal
 
 import (
 	"context"
-	"crypto/tls"
 	"fmt"
 	"net/http"
 
+	"github.com/tal-tech/go-zero/core/logx"
 	"github.com/tal-tech/go-zero/core/proc"
 )
 
+// StartOption defines the method to customize http.Server.
+type StartOption func(srv *http.Server)
+
 // StartHttp starts a http server.
-func StartHttp(host string, port int, handler http.Handler) error {
-	return start(host, port, handler, nil, func(srv *http.Server) error {
+func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error {
+	return start(host, port, handler, func(srv *http.Server) error {
 		return srv.ListenAndServe()
-	})
+	}, opts...)
 }
 
 // StartHttps starts a https server.
-func StartHttps(host string, port int, certFile, keyFile string, tlsConfig *tls.Config, handler http.Handler) error {
-	return start(host, port, handler, tlsConfig, func(srv *http.Server) error {
+func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
+	opts ...StartOption) error {
+	return start(host, port, handler, func(srv *http.Server) error {
 		// certFile and keyFile are set in buildHttpsServer
 		return srv.ListenAndServeTLS(certFile, keyFile)
-	})
+	}, opts...)
 }
 
-func start(host string, port int, handler http.Handler, tlsConfig *tls.Config, run func(srv *http.Server) error) (err error) {
+func start(host string, port int, handler http.Handler, run func(srv *http.Server) error,
+	opts ...StartOption) (err error) {
 	server := &http.Server{
 		Addr:    fmt.Sprintf("%s:%d", host, port),
 		Handler: handler,
 	}
-	if tlsConfig != nil {
-		server.TLSConfig = tlsConfig
+	for _, opt := range opts {
+		opt(server)
 	}
+
 	waitForCalled := proc.AddWrapUpListener(func() {
-		server.Shutdown(context.Background())
+		if e := server.Shutdown(context.Background()); err != nil {
+			logx.Error(e)
+		}
 	})
 	defer func() {
 		if err == http.ErrServerClosed {

+ 14 - 16
rest/server.go

@@ -48,8 +48,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
 	server := &Server{
 		ngin: newEngine(c),
 		opts: runOptions{
-			start: func(srv *engine) error {
-				return srv.Start()
+			start: func(ng *engine) error {
+				return ng.Start()
 			},
 		},
 	}
@@ -171,8 +171,8 @@ func WithPriority() RouteOption {
 // WithRouter returns a RunOption that make server run with given router.
 func WithRouter(router httpx.Router) RunOption {
 	return func(server *Server) {
-		server.opts.start = func(srv *engine) error {
-			return srv.StartWithRouter(router)
+		server.opts.start = func(ng *engine) error {
+			return ng.StartWithRouter(router)
 		}
 	}
 }
@@ -187,26 +187,24 @@ func WithSignature(signature SignatureConf) RouteOption {
 	}
 }
 
-// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
-func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
-	return func(engine *Server) {
-		engine.ngin.SetUnauthorizedCallback(callback)
+// WithTLSConfig returns a RunOption that with given tls config.
+func WithTLSConfig(cfg *tls.Config) RunOption {
+	return func(srv *Server) {
+		srv.ngin.setTlsConfig(cfg)
 	}
 }
 
-// WithTLSConfig returns a RunOption that with given tls config.
-func WithTLSConfig(cipherSuites []uint16) RunOption {
-	return func(engine *Server) {
-		engine.ngin.tlsConfig = &tls.Config{
-			CipherSuites: cipherSuites,
-		}
+// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
+func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
+	return func(srv *Server) {
+		srv.ngin.SetUnauthorizedCallback(callback)
 	}
 }
 
 // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
 func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
-	return func(engine *Server) {
-		engine.ngin.SetUnsignedCallback(callback)
+	return func(srv *Server) {
+		srv.ngin.SetUnsignedCallback(callback)
 	}
 }
 

+ 5 - 3
rest/server_test.go

@@ -227,8 +227,10 @@ Port: 54321
 	var cnf RestConf
 	assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
 
-	testConfig := []uint16{
-		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+	testConfig := &tls.Config{
+		CipherSuites: []uint16{
+			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+		},
 	}
 
 	testCases := []struct {
@@ -239,7 +241,7 @@ Port: 54321
 		{
 			c:    cnf,
 			opts: []RunOption{WithTLSConfig(testConfig)},
-			res:  &tls.Config{CipherSuites: testConfig},
+			res:  testConfig,
 		},
 		{
 			c:    cnf,