engine.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "sort"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/codec"
  10. "github.com/zeromicro/go-zero/core/load"
  11. "github.com/zeromicro/go-zero/core/stat"
  12. "github.com/zeromicro/go-zero/rest/chain"
  13. "github.com/zeromicro/go-zero/rest/handler"
  14. "github.com/zeromicro/go-zero/rest/httpx"
  15. "github.com/zeromicro/go-zero/rest/internal"
  16. "github.com/zeromicro/go-zero/rest/internal/response"
  17. )
  // topCpuUsage is the top of the CPU usage scale: usage is expressed in
  // "m" units, and 1000m represents 100%.
const (
	topCpuUsage = 1000
)
  var (
	// ErrSignatureConfig is an error that indicates bad config for signature:
	// signatureVerifier returns it when strict signature checking is enabled
	// but no private keys are configured.
	ErrSignatureConfig = errors.New("bad config for Signature")
)
  22. type engine struct {
  23. conf RestConf
  24. routes []featuredRoutes
  25. unauthorizedCallback handler.UnauthorizedCallback
  26. unsignedCallback handler.UnsignedCallback
  27. chain chain.Chain
  28. middlewares []Middleware
  29. shedder load.Shedder
  30. priorityShedder load.Shedder
  31. tlsConfig *tls.Config
  32. }
  33. func newEngine(c RestConf) *engine {
  34. svr := &engine{
  35. conf: c,
  36. }
  37. if c.CpuThreshold > 0 {
  38. svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
  39. svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
  40. (c.CpuThreshold + topCpuUsage) >> 1))
  41. }
  42. return svr
  43. }
  44. func (ng *engine) addRoutes(r featuredRoutes) {
  45. ng.routes = append(ng.routes, r)
  46. }
  47. func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
  48. verifier func(chain.Chain) chain.Chain) chain.Chain {
  49. if fr.jwt.enabled {
  50. if len(fr.jwt.prevSecret) == 0 {
  51. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  52. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  53. } else {
  54. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  55. handler.WithPrevSecret(fr.jwt.prevSecret),
  56. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  57. }
  58. }
  59. return verifier(chn)
  60. }
  61. func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
  62. verifier, err := ng.signatureVerifier(fr.signature)
  63. if err != nil {
  64. return err
  65. }
  66. for _, route := range fr.routes {
  67. if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
  68. return err
  69. }
  70. }
  71. return nil
  72. }
  73. func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
  74. route Route, verifier func(chain.Chain) chain.Chain) error {
  75. chn := ng.chain
  76. if chn == nil {
  77. chn = chain.New(
  78. handler.TracingHandler(ng.conf.Name, route.Path),
  79. ng.getLogHandler(),
  80. handler.PrometheusHandler(route.Path),
  81. handler.MaxConns(ng.conf.MaxConns),
  82. handler.BreakerHandler(route.Method, route.Path, metrics),
  83. handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
  84. handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
  85. handler.RecoverHandler,
  86. handler.MetricHandler(metrics),
  87. handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
  88. handler.GunzipHandler,
  89. )
  90. }
  91. chn = ng.appendAuthHandler(fr, chn, verifier)
  92. for _, middleware := range ng.middlewares {
  93. chn = chn.Append(convertMiddleware(middleware))
  94. }
  95. handle := chn.ThenFunc(route.Handler)
  96. return router.Handle(route.Method, route.Path, handle)
  97. }
  98. func (ng *engine) bindRoutes(router httpx.Router) error {
  99. metrics := ng.createMetrics()
  100. for _, fr := range ng.routes {
  101. if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
  102. return err
  103. }
  104. }
  105. return nil
  106. }
  107. func (ng *engine) checkedMaxBytes(bytes int64) int64 {
  108. if bytes > 0 {
  109. return bytes
  110. }
  111. return ng.conf.MaxBytes
  112. }
  // checkedTimeout returns timeout when it is positive. Otherwise it falls back
// to ng.conf.Timeout, which is interpreted as a number of milliseconds.
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
	if timeout <= 0 {
		return time.Millisecond * time.Duration(ng.conf.Timeout)
	}
	return timeout
}
  119. func (ng *engine) createMetrics() *stat.Metrics {
  120. var metrics *stat.Metrics
  121. if len(ng.conf.Name) > 0 {
  122. metrics = stat.NewMetrics(ng.conf.Name)
  123. } else {
  124. metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
  125. }
  126. return metrics
  127. }
  // getLogHandler selects the request-logging middleware. It returns
// handler.LogHandler by default, or handler.DetailedLogHandler when
// ng.conf.Verbose is set.
func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
	if !ng.conf.Verbose {
		return handler.LogHandler
	}
	return handler.DetailedLogHandler
}
  134. func (ng *engine) getShedder(priority bool) load.Shedder {
  135. if priority && ng.priorityShedder != nil {
  136. return ng.priorityShedder
  137. }
  138. return ng.shedder
  139. }
  140. // notFoundHandler returns a middleware that handles 404 not found requests.
  141. func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
  142. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  143. chn := chain.New(
  144. handler.TracingHandler(ng.conf.Name, ""),
  145. ng.getLogHandler(),
  146. )
  147. var h http.Handler
  148. if next != nil {
  149. h = chn.Then(next)
  150. } else {
  151. h = chn.Then(http.NotFoundHandler())
  152. }
  153. cw := response.NewHeaderOnceResponseWriter(w)
  154. h.ServeHTTP(cw, r)
  155. cw.WriteHeader(http.StatusNotFound)
  156. })
  157. }
  // print writes every registered route to stdout as a "METHOD PATH" line.
//
// The lines are sorted lexicographically and appear under a "Routes:" heading.
func (ng *engine) print() {
	var lines []string
	for gi := range ng.routes {
		group := ng.routes[gi].routes
		for ri := range group {
			lines = append(lines, fmt.Sprintf("%s %s", group[ri].Method, group[ri].Path))
		}
	}
	sort.Strings(lines)

	fmt.Println("Routes:")
	for i := range lines {
		fmt.Printf(" %s\n", lines[i])
	}
}
  // setTlsConfig stores the TLS configuration that start installs on the HTTPS
// server. A nil config leaves the server's TLSConfig untouched.
func (ng *engine) setTlsConfig(config *tls.Config) {
	ng.tlsConfig = config
}
  174. func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
  175. ng.unauthorizedCallback = callback
  176. }
  177. func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
  178. ng.unsignedCallback = callback
  179. }
  // signatureVerifier builds the function that adds signature (content
// security) verification to a route chain, according to signature.
//
// It returns a pass-through function in two cases:
//   - signature verification is disabled;
//   - no private keys are configured and strict mode is off.
//
// In strict mode with no private keys, it returns ErrSignatureConfig.
//
// Otherwise, every configured key file is loaded as an RSA decrypter, keyed
// by its fingerprint. A later key with the same fingerprint replaces an
// earlier one. If a key file fails to load, the returned error names that
// file and wraps the underlying error, so errors.Is/As still work.
//
// ng.unsignedCallback is read when the returned function is invoked, not when
// signatureVerifier runs.
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
	passThrough := func(chn chain.Chain) chain.Chain {
		return chn
	}

	if !signature.enabled {
		return passThrough, nil
	}
	if len(signature.PrivateKeys) == 0 {
		if signature.Strict {
			return nil, ErrSignatureConfig
		}
		return passThrough, nil
	}

	decrypters := make(map[string]codec.RsaDecrypter, len(signature.PrivateKeys))
	for _, key := range signature.PrivateKeys {
		decrypter, err := codec.NewRsaDecrypter(key.KeyFile)
		if err != nil {
			// Name the offending file: with several keys configured, the bare
			// error gives no hint which one failed.
			return nil, fmt.Errorf("loading signature private key %q: %w", key.KeyFile, err)
		}
		decrypters[key.Fingerprint] = decrypter
	}

	return func(chn chain.Chain) chain.Chain {
		if ng.unsignedCallback != nil {
			return chn.Append(handler.ContentSecurityHandler(
				decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
		}
		return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
	}, nil
}
  // start binds all queued routes to router and then serves it.
//
// If neither ng.conf.CertFile nor ng.conf.KeyFile is set, the server uses
// plain HTTP. Otherwise it uses HTTPS, and ng.tlsConfig is installed on the
// server when it is non-nil. Both modes apply the timeouts from withTimeout.
// A route binding error is returned before anything is served.
func (ng *engine) start(router httpx.Router) error {
	if err := ng.bindRoutes(router); err != nil {
		return err
	}

	serveTLS := len(ng.conf.CertFile) > 0 || len(ng.conf.KeyFile) > 0
	if !serveTLS {
		return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, ng.withTimeout())
	}

	applyTLSConfig := func(svr *http.Server) {
		if ng.tlsConfig != nil {
			svr.TLSConfig = ng.tlsConfig
		}
	}
	return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
		ng.conf.KeyFile, router, applyTLSConfig, ng.withTimeout())
}
  226. func (ng *engine) use(middleware Middleware) {
  227. ng.middlewares = append(ng.middlewares, middleware)
  228. }
  // withTimeout returns a StartOption that derives the server's read and write
// timeouts from ng.conf.Timeout, in milliseconds. ng.conf.Timeout is read when
// the option is applied. A non-positive value leaves the server untouched.
func (ng *engine) withTimeout() internal.StartOption {
	return func(svr *http.Server) {
		timeout := ng.conf.Timeout
		if timeout <= 0 {
			return
		}

		// Both server timeouts stay below the request timeout. A client that
		// sends a longer Content-Length than its real body (read side), or that
		// never reads the response (write side), is then cut off by the server.
		// Otherwise the request would time out with 503 Service Unavailable,
		// which triggers the circuit breaker.
		svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5   // 80%
		svr.WriteTimeout = 9 * time.Duration(timeout) * time.Millisecond / 10 // 90%
	}
}
  // convertMiddleware adapts a Middleware into the func(http.Handler)
// http.Handler form that bindRoute appends to a chain. The wrapped handler's
// ServeHTTP method is handed to ware as an http.HandlerFunc.
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		inner := http.HandlerFunc(next.ServeHTTP)
		return ware(inner)
	}
}