engine.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "time"
  8. "github.com/justinas/alice"
  9. "github.com/tal-tech/go-zero/core/codec"
  10. "github.com/tal-tech/go-zero/core/load"
  11. "github.com/tal-tech/go-zero/core/stat"
  12. "github.com/tal-tech/go-zero/rest/handler"
  13. "github.com/tal-tech/go-zero/rest/httpx"
  14. "github.com/tal-tech/go-zero/rest/internal"
  15. )
// topCpuUsage is the top of the CPU usage scale, expressed in millicpu:
// 1000m represents 100%.
const topCpuUsage = 1000
// ErrSignatureConfig is an error that indicates bad config for signature.
// signatureVerifier returns it when signature checking is enabled in strict
// mode but no private keys are configured.
var ErrSignatureConfig = errors.New("bad config for Signature")
// engine holds the configuration, registered routes and per-route handler
// settings used to bind routes onto a router and start the server.
type engine struct {
	// conf is the configuration the engine was created with.
	conf RestConf
	// routes collects the route groups registered through addRoutes.
	routes []featuredRoutes
	// unauthorizedCallback is passed to the JWT authorize handler of
	// JWT-enabled route groups.
	unauthorizedCallback handler.UnauthorizedCallback
	// unsignedCallback, when non-nil, is passed to the content security
	// (signature) handler.
	unsignedCallback handler.UnsignedCallback
	// middlewares are the user middlewares registered through use; bindRoute
	// appends them after the built-in, auth and signature handlers.
	middlewares []Middleware
	// shedder is the regular load shedder; it is only created when
	// conf.CpuThreshold is positive.
	shedder load.Shedder
	// priorityShedder is the shedder used for priority route groups; it is
	// only created when conf.CpuThreshold is positive.
	priorityShedder load.Shedder
	// tlsConfig, when non-nil, is installed on the HTTPS server in start.
	tlsConfig *tls.Config
}
  30. func newEngine(c RestConf) *engine {
  31. srv := &engine{
  32. conf: c,
  33. }
  34. if c.CpuThreshold > 0 {
  35. srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
  36. srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
  37. (c.CpuThreshold + topCpuUsage) >> 1))
  38. }
  39. return srv
  40. }
// addRoutes registers a route group; the group is bound to the router later
// by bindRoutes.
func (ng *engine) addRoutes(r featuredRoutes) {
	ng.routes = append(ng.routes, r)
}
  44. func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
  45. verifier func(alice.Chain) alice.Chain) alice.Chain {
  46. if fr.jwt.enabled {
  47. if len(fr.jwt.prevSecret) == 0 {
  48. chain = chain.Append(handler.Authorize(fr.jwt.secret,
  49. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  50. } else {
  51. chain = chain.Append(handler.Authorize(fr.jwt.secret,
  52. handler.WithPrevSecret(fr.jwt.prevSecret),
  53. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  54. }
  55. }
  56. return verifier(chain)
  57. }
  58. func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
  59. verifier, err := ng.signatureVerifier(fr.signature)
  60. if err != nil {
  61. return err
  62. }
  63. for _, route := range fr.routes {
  64. if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
  65. return err
  66. }
  67. }
  68. return nil
  69. }
  70. func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
  71. route Route, verifier func(chain alice.Chain) alice.Chain) error {
  72. chain := alice.New(
  73. handler.TracingHandler(ng.conf.Name, route.Path),
  74. ng.getLogHandler(),
  75. handler.PrometheusHandler(route.Path),
  76. handler.MaxConns(ng.conf.MaxConns),
  77. handler.BreakerHandler(route.Method, route.Path, metrics),
  78. handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
  79. handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
  80. handler.RecoverHandler,
  81. handler.MetricHandler(metrics),
  82. handler.MaxBytesHandler(ng.conf.MaxBytes),
  83. handler.GunzipHandler,
  84. )
  85. chain = ng.appendAuthHandler(fr, chain, verifier)
  86. for _, middleware := range ng.middlewares {
  87. chain = chain.Append(convertMiddleware(middleware))
  88. }
  89. handle := chain.ThenFunc(route.Handler)
  90. return router.Handle(route.Method, route.Path, handle)
  91. }
  92. func (ng *engine) bindRoutes(router httpx.Router) error {
  93. metrics := ng.createMetrics()
  94. for _, fr := range ng.routes {
  95. if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
  96. return err
  97. }
  98. }
  99. return nil
  100. }
  101. func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
  102. if timeout > 0 {
  103. return timeout
  104. }
  105. return time.Duration(ng.conf.Timeout) * time.Millisecond
  106. }
  107. func (ng *engine) createMetrics() *stat.Metrics {
  108. var metrics *stat.Metrics
  109. if len(ng.conf.Name) > 0 {
  110. metrics = stat.NewMetrics(ng.conf.Name)
  111. } else {
  112. metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
  113. }
  114. return metrics
  115. }
  116. func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
  117. if ng.conf.Verbose {
  118. return handler.DetailedLogHandler
  119. }
  120. return handler.LogHandler
  121. }
  122. func (ng *engine) getShedder(priority bool) load.Shedder {
  123. if priority && ng.priorityShedder != nil {
  124. return ng.priorityShedder
  125. }
  126. return ng.shedder
  127. }
// setTlsConfig stores cfg; start installs it on the HTTPS server when it is
// non-nil.
func (ng *engine) setTlsConfig(cfg *tls.Config) {
	ng.tlsConfig = cfg
}
// setUnauthorizedCallback stores the callback that appendAuthHandler passes
// to the JWT authorize handler.
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
	ng.unauthorizedCallback = callback
}
// setUnsignedCallback stores the callback that signatureVerifier passes to
// the content security handler when it is non-nil.
func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
	ng.unsignedCallback = callback
}
  137. func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
  138. if !signature.enabled {
  139. return func(chain alice.Chain) alice.Chain {
  140. return chain
  141. }, nil
  142. }
  143. if len(signature.PrivateKeys) == 0 {
  144. if signature.Strict {
  145. return nil, ErrSignatureConfig
  146. }
  147. return func(chain alice.Chain) alice.Chain {
  148. return chain
  149. }, nil
  150. }
  151. decrypters := make(map[string]codec.RsaDecrypter)
  152. for _, key := range signature.PrivateKeys {
  153. fingerprint := key.Fingerprint
  154. file := key.KeyFile
  155. decrypter, err := codec.NewRsaDecrypter(file)
  156. if err != nil {
  157. return nil, err
  158. }
  159. decrypters[fingerprint] = decrypter
  160. }
  161. return func(chain alice.Chain) alice.Chain {
  162. if ng.unsignedCallback != nil {
  163. return chain.Append(handler.ContentSecurityHandler(
  164. decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
  165. }
  166. return chain.Append(handler.ContentSecurityHandler(
  167. decrypters, signature.Expiry, signature.Strict))
  168. }, nil
  169. }
  170. func (ng *engine) start(router httpx.Router) error {
  171. if err := ng.bindRoutes(router); err != nil {
  172. return err
  173. }
  174. if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
  175. return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
  176. }
  177. return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
  178. ng.conf.KeyFile, router, func(srv *http.Server) {
  179. if ng.tlsConfig != nil {
  180. srv.TLSConfig = ng.tlsConfig
  181. }
  182. })
  183. }
// use registers a middleware; bindRoute appends the registered middlewares,
// in registration order, to the handler chain of each route it binds.
func (ng *engine) use(middleware Middleware) {
	ng.middlewares = append(ng.middlewares, middleware)
}
  187. func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
  188. return func(next http.Handler) http.Handler {
  189. return ware(next.ServeHTTP)
  190. }
  191. }