소스 검색

rename rest files

kevin 4 년 전
부모
커밋
054d9b5540
4개의 변경된 파일322개의 추가작업 그리고 322개의 파일을 삭제
  1. 214 0
      rest/engine.go
  2. 0 170
      rest/ngin.go
  3. 108 152
      rest/server.go
  4. 0 0
      rest/server_test.go

+ 214 - 0
rest/engine.go

@@ -0,0 +1,214 @@
+package rest
+
+import (
+	"errors"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/justinas/alice"
+	"github.com/tal-tech/go-zero/core/codec"
+	"github.com/tal-tech/go-zero/core/load"
+	"github.com/tal-tech/go-zero/core/stat"
+	"github.com/tal-tech/go-zero/rest/handler"
+	"github.com/tal-tech/go-zero/rest/httpx"
+	"github.com/tal-tech/go-zero/rest/internal"
+	"github.com/tal-tech/go-zero/rest/router"
+)
+
+// use 1000m to represent 100%
+const topCpuUsage = 1000
+
+var ErrSignatureConfig = errors.New("bad config for Signature")
+
+type engine struct {
+	conf                 RestConf
+	routes               []featuredRoutes
+	unauthorizedCallback handler.UnauthorizedCallback
+	unsignedCallback     handler.UnsignedCallback
+	middlewares          []Middleware
+	shedder              load.Shedder
+	priorityShedder      load.Shedder
+}
+
+func newEngine(c RestConf) *engine {
+	srv := &engine{
+		conf: c,
+	}
+	if c.CpuThreshold > 0 {
+		srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
+		srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
+			(c.CpuThreshold + topCpuUsage) >> 1))
+	}
+
+	return srv
+}
+
+func (s *engine) AddRoutes(r featuredRoutes) {
+	s.routes = append(s.routes, r)
+}
+
+func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
+	s.unauthorizedCallback = callback
+}
+
+func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
+	s.unsignedCallback = callback
+}
+
+func (s *engine) Start() error {
+	return s.StartWithRouter(router.NewPatRouter())
+}
+
+func (s *engine) StartWithRouter(router httpx.Router) error {
+	if err := s.bindRoutes(router); err != nil {
+		return err
+	}
+
+	return internal.StartHttp(s.conf.Host, s.conf.Port, router)
+}
+
+func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
+	verifier func(alice.Chain) alice.Chain) alice.Chain {
+	if fr.jwt.enabled {
+		if len(fr.jwt.prevSecret) == 0 {
+			chain = chain.Append(handler.Authorize(fr.jwt.secret,
+				handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
+		} else {
+			chain = chain.Append(handler.Authorize(fr.jwt.secret,
+				handler.WithPrevSecret(fr.jwt.prevSecret),
+				handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
+		}
+	}
+
+	return verifier(chain)
+}
+
+func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
+	verifier, err := s.signatureVerifier(fr.signature)
+	if err != nil {
+		return err
+	}
+
+	for _, route := range fr.routes {
+		if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
+	route Route, verifier func(chain alice.Chain) alice.Chain) error {
+	chain := alice.New(
+		handler.TracingHandler,
+		s.getLogHandler(),
+		handler.MaxConns(s.conf.MaxConns),
+		handler.BreakerHandler(route.Method, route.Path, metrics),
+		handler.SheddingHandler(s.getShedder(fr.priority), metrics),
+		handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
+		handler.RecoverHandler,
+		handler.MetricHandler(metrics),
+		handler.PromMetricHandler(route.Path),
+		handler.MaxBytesHandler(s.conf.MaxBytes),
+		handler.GunzipHandler,
+	)
+	chain = s.appendAuthHandler(fr, chain, verifier)
+
+	for _, middleware := range s.middlewares {
+		chain = chain.Append(convertMiddleware(middleware))
+	}
+	handle := chain.ThenFunc(route.Handler)
+
+	return router.Handle(route.Method, route.Path, handle)
+}
+
+func (s *engine) bindRoutes(router httpx.Router) error {
+	metrics := s.createMetrics()
+
+	for _, fr := range s.routes {
+		if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (s *engine) createMetrics() *stat.Metrics {
+	var metrics *stat.Metrics
+
+	if len(s.conf.Name) > 0 {
+		metrics = stat.NewMetrics(s.conf.Name)
+	} else {
+		metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
+	}
+
+	return metrics
+}
+
+func (s *engine) getLogHandler() func(http.Handler) http.Handler {
+	if s.conf.Verbose {
+		return handler.DetailedLogHandler
+	} else {
+		return handler.LogHandler
+	}
+}
+
+func (s *engine) getShedder(priority bool) load.Shedder {
+	if priority && s.priorityShedder != nil {
+		return s.priorityShedder
+	}
+	return s.shedder
+}
+
+func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
+	if !signature.enabled {
+		return func(chain alice.Chain) alice.Chain {
+			return chain
+		}, nil
+	}
+
+	if len(signature.PrivateKeys) == 0 {
+		if signature.Strict {
+			return nil, ErrSignatureConfig
+		} else {
+			return func(chain alice.Chain) alice.Chain {
+				return chain
+			}, nil
+		}
+	}
+
+	decrypters := make(map[string]codec.RsaDecrypter)
+	for _, key := range signature.PrivateKeys {
+		fingerprint := key.Fingerprint
+		file := key.KeyFile
+		decrypter, err := codec.NewRsaDecrypter(file)
+		if err != nil {
+			return nil, err
+		}
+
+		decrypters[fingerprint] = decrypter
+	}
+
+	return func(chain alice.Chain) alice.Chain {
+		if s.unsignedCallback != nil {
+			return chain.Append(handler.ContentSecurityHandler(
+				decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
+		} else {
+			return chain.Append(handler.ContentSecurityHandler(
+				decrypters, signature.Expiry, signature.Strict))
+		}
+	}, nil
+}
+
+func (s *engine) use(middleware Middleware) {
+	s.middlewares = append(s.middlewares, middleware)
+}
+
+func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(ware(next.ServeHTTP))
+	}
+}

+ 0 - 170
rest/ngin.go

@@ -1,170 +0,0 @@
-package rest
-
-import (
-	"log"
-	"net/http"
-
-	"github.com/tal-tech/go-zero/core/logx"
-	"github.com/tal-tech/go-zero/rest/handler"
-	"github.com/tal-tech/go-zero/rest/httpx"
-)
-
-type (
-	runOptions struct {
-		start func(*engine) error
-	}
-
-	RunOption func(*Server)
-
-	Server struct {
-		ngin *engine
-		opts runOptions
-	}
-)
-
-func MustNewServer(c RestConf, opts ...RunOption) *Server {
-	engine, err := NewServer(c, opts...)
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	return engine
-}
-
-func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
-	if err := c.SetUp(); err != nil {
-		return nil, err
-	}
-
-	server := &Server{
-		ngin: newEngine(c),
-		opts: runOptions{
-			start: func(srv *engine) error {
-				return srv.Start()
-			},
-		},
-	}
-
-	for _, opt := range opts {
-		opt(server)
-	}
-
-	return server, nil
-}
-
-func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) {
-	r := featuredRoutes{
-		routes: rs,
-	}
-	for _, opt := range opts {
-		opt(&r)
-	}
-	e.ngin.AddRoutes(r)
-}
-
-func (e *Server) AddRoute(r Route, opts ...RouteOption) {
-	e.AddRoutes([]Route{r}, opts...)
-}
-
-func (e *Server) Start() {
-	handleError(e.opts.start(e.ngin))
-}
-
-func (e *Server) Stop() {
-	logx.Close()
-}
-
-func (e *Server) Use(middleware Middleware) {
-	e.ngin.use(middleware)
-}
-
-func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
-	return func(handle http.HandlerFunc) http.HandlerFunc {
-		return handler(handle).ServeHTTP
-	}
-}
-
-func WithJwt(secret string) RouteOption {
-	return func(r *featuredRoutes) {
-		validateSecret(secret)
-		r.jwt.enabled = true
-		r.jwt.secret = secret
-	}
-}
-
-func WithJwtTransition(secret, prevSecret string) RouteOption {
-	return func(r *featuredRoutes) {
-		// why not validate prevSecret, because prevSecret is an already used one,
-		// even it not meet our requirement, we still need to allow the transition.
-		validateSecret(secret)
-		r.jwt.enabled = true
-		r.jwt.secret = secret
-		r.jwt.prevSecret = prevSecret
-	}
-}
-
-func WithMiddleware(middleware Middleware, rs ...Route) []Route {
-	routes := make([]Route, len(rs))
-
-	for i := range rs {
-		route := rs[i]
-		routes[i] = Route{
-			Method:  route.Method,
-			Path:    route.Path,
-			Handler: middleware(route.Handler),
-		}
-	}
-
-	return routes
-}
-
-func WithPriority() RouteOption {
-	return func(r *featuredRoutes) {
-		r.priority = true
-	}
-}
-
-func WithRouter(router httpx.Router) RunOption {
-	return func(server *Server) {
-		server.opts.start = func(srv *engine) error {
-			return srv.StartWithRouter(router)
-		}
-	}
-}
-
-func WithSignature(signature SignatureConf) RouteOption {
-	return func(r *featuredRoutes) {
-		r.signature.enabled = true
-		r.signature.Strict = signature.Strict
-		r.signature.Expiry = signature.Expiry
-		r.signature.PrivateKeys = signature.PrivateKeys
-	}
-}
-
-func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
-	return func(engine *Server) {
-		engine.ngin.SetUnauthorizedCallback(callback)
-	}
-}
-
-func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
-	return func(engine *Server) {
-		engine.ngin.SetUnsignedCallback(callback)
-	}
-}
-
-func handleError(err error) {
-	// ErrServerClosed means the server is closed manually
-	if err == nil || err == http.ErrServerClosed {
-		return
-	}
-
-	logx.Error(err)
-	panic(err)
-}
-
-func validateSecret(secret string) {
-	if len(secret) < 8 {
-		panic("secret's length can't be less than 8")
-	}
-}

+ 108 - 152
rest/server.go

@@ -1,214 +1,170 @@
 package rest
 
 import (
-	"errors"
-	"fmt"
+	"log"
 	"net/http"
-	"time"
 
-	"github.com/justinas/alice"
-	"github.com/tal-tech/go-zero/core/codec"
-	"github.com/tal-tech/go-zero/core/load"
-	"github.com/tal-tech/go-zero/core/stat"
+	"github.com/tal-tech/go-zero/core/logx"
 	"github.com/tal-tech/go-zero/rest/handler"
 	"github.com/tal-tech/go-zero/rest/httpx"
-	"github.com/tal-tech/go-zero/rest/internal"
-	"github.com/tal-tech/go-zero/rest/router"
 )
 
-// use 1000m to represent 100%
-const topCpuUsage = 1000
+type (
+	runOptions struct {
+		start func(*engine) error
+	}
 
-var ErrSignatureConfig = errors.New("bad config for Signature")
+	RunOption func(*Server)
 
-type engine struct {
-	conf                 RestConf
-	routes               []featuredRoutes
-	unauthorizedCallback handler.UnauthorizedCallback
-	unsignedCallback     handler.UnsignedCallback
-	middlewares          []Middleware
-	shedder              load.Shedder
-	priorityShedder      load.Shedder
+	Server struct {
+		ngin *engine
+		opts runOptions
+	}
+)
+
+func MustNewServer(c RestConf, opts ...RunOption) *Server {
+	engine, err := NewServer(c, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return engine
 }
 
-func newEngine(c RestConf) *engine {
-	srv := &engine{
-		conf: c,
+func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
+	if err := c.SetUp(); err != nil {
+		return nil, err
 	}
-	if c.CpuThreshold > 0 {
-		srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
-		srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
-			(c.CpuThreshold + topCpuUsage) >> 1))
+
+	server := &Server{
+		ngin: newEngine(c),
+		opts: runOptions{
+			start: func(srv *engine) error {
+				return srv.Start()
+			},
+		},
 	}
 
-	return srv
-}
+	for _, opt := range opts {
+		opt(server)
+	}
 
-func (s *engine) AddRoutes(r featuredRoutes) {
-	s.routes = append(s.routes, r)
+	return server, nil
 }
 
-func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
-	s.unauthorizedCallback = callback
+func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) {
+	r := featuredRoutes{
+		routes: rs,
+	}
+	for _, opt := range opts {
+		opt(&r)
+	}
+	e.ngin.AddRoutes(r)
 }
 
-func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
-	s.unsignedCallback = callback
+func (e *Server) AddRoute(r Route, opts ...RouteOption) {
+	e.AddRoutes([]Route{r}, opts...)
 }
 
-func (s *engine) Start() error {
-	return s.StartWithRouter(router.NewPatRouter())
+func (e *Server) Start() {
+	handleError(e.opts.start(e.ngin))
 }
 
-func (s *engine) StartWithRouter(router httpx.Router) error {
-	if err := s.bindRoutes(router); err != nil {
-		return err
-	}
-
-	return internal.StartHttp(s.conf.Host, s.conf.Port, router)
+func (e *Server) Stop() {
+	logx.Close()
 }
 
-func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
-	verifier func(alice.Chain) alice.Chain) alice.Chain {
-	if fr.jwt.enabled {
-		if len(fr.jwt.prevSecret) == 0 {
-			chain = chain.Append(handler.Authorize(fr.jwt.secret,
-				handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
-		} else {
-			chain = chain.Append(handler.Authorize(fr.jwt.secret,
-				handler.WithPrevSecret(fr.jwt.prevSecret),
-				handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
-		}
-	}
-
-	return verifier(chain)
+func (e *Server) Use(middleware Middleware) {
+	e.ngin.use(middleware)
 }
 
-func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
-	verifier, err := s.signatureVerifier(fr.signature)
-	if err != nil {
-		return err
+func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
+	return func(handle http.HandlerFunc) http.HandlerFunc {
+		return handler(handle).ServeHTTP
 	}
+}
 
-	for _, route := range fr.routes {
-		if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
-			return err
-		}
+func WithJwt(secret string) RouteOption {
+	return func(r *featuredRoutes) {
+		validateSecret(secret)
+		r.jwt.enabled = true
+		r.jwt.secret = secret
 	}
-
-	return nil
 }
 
-func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
-	route Route, verifier func(chain alice.Chain) alice.Chain) error {
-	chain := alice.New(
-		handler.TracingHandler,
-		s.getLogHandler(),
-		handler.MaxConns(s.conf.MaxConns),
-		handler.BreakerHandler(route.Method, route.Path, metrics),
-		handler.SheddingHandler(s.getShedder(fr.priority), metrics),
-		handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
-		handler.RecoverHandler,
-		handler.MetricHandler(metrics),
-		handler.PromMetricHandler(route.Path),
-		handler.MaxBytesHandler(s.conf.MaxBytes),
-		handler.GunzipHandler,
-	)
-	chain = s.appendAuthHandler(fr, chain, verifier)
-
-	for _, middleware := range s.middlewares {
-		chain = chain.Append(convertMiddleware(middleware))
+func WithJwtTransition(secret, prevSecret string) RouteOption {
+	return func(r *featuredRoutes) {
+		// why not validate prevSecret, because prevSecret is an already used one,
+		// even it not meet our requirement, we still need to allow the transition.
+		validateSecret(secret)
+		r.jwt.enabled = true
+		r.jwt.secret = secret
+		r.jwt.prevSecret = prevSecret
 	}
-	handle := chain.ThenFunc(route.Handler)
-
-	return router.Handle(route.Method, route.Path, handle)
 }
 
-func (s *engine) bindRoutes(router httpx.Router) error {
-	metrics := s.createMetrics()
+func WithMiddleware(middleware Middleware, rs ...Route) []Route {
+	routes := make([]Route, len(rs))
 
-	for _, fr := range s.routes {
-		if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
-			return err
+	for i := range rs {
+		route := rs[i]
+		routes[i] = Route{
+			Method:  route.Method,
+			Path:    route.Path,
+			Handler: middleware(route.Handler),
 		}
 	}
 
-	return nil
+	return routes
 }
 
-func (s *engine) createMetrics() *stat.Metrics {
-	var metrics *stat.Metrics
-
-	if len(s.conf.Name) > 0 {
-		metrics = stat.NewMetrics(s.conf.Name)
-	} else {
-		metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
+func WithPriority() RouteOption {
+	return func(r *featuredRoutes) {
+		r.priority = true
 	}
-
-	return metrics
 }
 
-func (s *engine) getLogHandler() func(http.Handler) http.Handler {
-	if s.conf.Verbose {
-		return handler.DetailedLogHandler
-	} else {
-		return handler.LogHandler
+func WithRouter(router httpx.Router) RunOption {
+	return func(server *Server) {
+		server.opts.start = func(srv *engine) error {
+			return srv.StartWithRouter(router)
+		}
 	}
 }
 
-func (s *engine) getShedder(priority bool) load.Shedder {
-	if priority && s.priorityShedder != nil {
-		return s.priorityShedder
+func WithSignature(signature SignatureConf) RouteOption {
+	return func(r *featuredRoutes) {
+		r.signature.enabled = true
+		r.signature.Strict = signature.Strict
+		r.signature.Expiry = signature.Expiry
+		r.signature.PrivateKeys = signature.PrivateKeys
 	}
-	return s.shedder
 }
 
-func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
-	if !signature.enabled {
-		return func(chain alice.Chain) alice.Chain {
-			return chain
-		}, nil
+func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
+	return func(engine *Server) {
+		engine.ngin.SetUnauthorizedCallback(callback)
 	}
+}
 
-	if len(signature.PrivateKeys) == 0 {
-		if signature.Strict {
-			return nil, ErrSignatureConfig
-		} else {
-			return func(chain alice.Chain) alice.Chain {
-				return chain
-			}, nil
-		}
+func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
+	return func(engine *Server) {
+		engine.ngin.SetUnsignedCallback(callback)
 	}
+}
 
-	decrypters := make(map[string]codec.RsaDecrypter)
-	for _, key := range signature.PrivateKeys {
-		fingerprint := key.Fingerprint
-		file := key.KeyFile
-		decrypter, err := codec.NewRsaDecrypter(file)
-		if err != nil {
-			return nil, err
-		}
-
-		decrypters[fingerprint] = decrypter
+func handleError(err error) {
+	// ErrServerClosed means the server is closed manually
+	if err == nil || err == http.ErrServerClosed {
+		return
 	}
 
-	return func(chain alice.Chain) alice.Chain {
-		if s.unsignedCallback != nil {
-			return chain.Append(handler.ContentSecurityHandler(
-				decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
-		} else {
-			return chain.Append(handler.ContentSecurityHandler(
-				decrypters, signature.Expiry, signature.Strict))
-		}
-	}, nil
-}
-
-func (s *engine) use(middleware Middleware) {
-	s.middlewares = append(s.middlewares, middleware)
+	logx.Error(err)
+	panic(err)
 }
 
-func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
-	return func(next http.Handler) http.Handler {
-		return http.HandlerFunc(ware(next.ServeHTTP))
+func validateSecret(secret string) {
+	if len(secret) < 8 {
+		panic("secret's length can't be less than 8")
 	}
 }

+ 0 - 0
rest/ngin_test.go → rest/server_test.go