engine.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "sort"
  8. "time"
  9. "github.com/justinas/alice"
  10. "github.com/zeromicro/go-zero/core/codec"
  11. "github.com/zeromicro/go-zero/core/load"
  12. "github.com/zeromicro/go-zero/core/stat"
  13. "github.com/zeromicro/go-zero/rest/handler"
  14. "github.com/zeromicro/go-zero/rest/httpx"
  15. "github.com/zeromicro/go-zero/rest/internal"
  16. "github.com/zeromicro/go-zero/rest/internal/response"
  17. )
  // topCpuUsage is the CPU usage value that represents 100%, on a scale
  // where 1000m means 100%.
  const (
  	topCpuUsage = 1000
  )

  // ErrSignatureConfig is an error that indicates bad config for signature.
  // signatureVerifier returns it when signature checking is enabled in strict
  // mode but no private keys are configured.
  var (
  	ErrSignatureConfig = errors.New("bad config for Signature")
  )
  22. type engine struct {
  23. conf RestConf
  24. routes []featuredRoutes
  25. unauthorizedCallback handler.UnauthorizedCallback
  26. unsignedCallback handler.UnsignedCallback
  27. middlewares []Middleware
  28. shedder load.Shedder
  29. priorityShedder load.Shedder
  30. tlsConfig *tls.Config
  31. }
  32. func newEngine(c RestConf) *engine {
  33. svr := &engine{
  34. conf: c,
  35. }
  36. if c.CpuThreshold > 0 {
  37. svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
  38. svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
  39. (c.CpuThreshold + topCpuUsage) >> 1))
  40. }
  41. return svr
  42. }
  43. func (ng *engine) addRoutes(r featuredRoutes) {
  44. ng.routes = append(ng.routes, r)
  45. }
  46. func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
  47. verifier func(alice.Chain) alice.Chain) alice.Chain {
  48. if fr.jwt.enabled {
  49. if len(fr.jwt.prevSecret) == 0 {
  50. chain = chain.Append(handler.Authorize(fr.jwt.secret,
  51. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  52. } else {
  53. chain = chain.Append(handler.Authorize(fr.jwt.secret,
  54. handler.WithPrevSecret(fr.jwt.prevSecret),
  55. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  56. }
  57. }
  58. return verifier(chain)
  59. }
  60. func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
  61. verifier, err := ng.signatureVerifier(fr.signature)
  62. if err != nil {
  63. return err
  64. }
  65. for _, route := range fr.routes {
  66. if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
  67. return err
  68. }
  69. }
  70. return nil
  71. }
  72. func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
  73. route Route, verifier func(chain alice.Chain) alice.Chain) error {
  74. chain := alice.New(
  75. handler.TracingHandler(ng.conf.Name, route.Path),
  76. ng.getLogHandler(),
  77. handler.PrometheusHandler(route.Path),
  78. handler.MaxConns(ng.conf.MaxConns),
  79. handler.BreakerHandler(route.Method, route.Path, metrics),
  80. handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
  81. handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
  82. handler.RecoverHandler,
  83. handler.MetricHandler(metrics),
  84. handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
  85. handler.GunzipHandler,
  86. )
  87. chain = ng.appendAuthHandler(fr, chain, verifier)
  88. for _, middleware := range ng.middlewares {
  89. chain = chain.Append(convertMiddleware(middleware))
  90. }
  91. handle := chain.ThenFunc(route.Handler)
  92. return router.Handle(route.Method, route.Path, handle)
  93. }
  94. func (ng *engine) bindRoutes(router httpx.Router) error {
  95. metrics := ng.createMetrics()
  96. for _, fr := range ng.routes {
  97. if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
  98. return err
  99. }
  100. }
  101. return nil
  102. }
  103. func (ng *engine) checkedMaxBytes(bytes int64) int64 {
  104. if bytes > 0 {
  105. return bytes
  106. }
  107. return ng.conf.MaxBytes
  108. }
  109. func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
  110. if timeout > 0 {
  111. return timeout
  112. }
  113. return time.Duration(ng.conf.Timeout) * time.Millisecond
  114. }
  115. func (ng *engine) createMetrics() *stat.Metrics {
  116. var metrics *stat.Metrics
  117. if len(ng.conf.Name) > 0 {
  118. metrics = stat.NewMetrics(ng.conf.Name)
  119. } else {
  120. metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
  121. }
  122. return metrics
  123. }
  124. func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
  125. if ng.conf.Verbose {
  126. return handler.DetailedLogHandler
  127. }
  128. return handler.LogHandler
  129. }
  130. func (ng *engine) getShedder(priority bool) load.Shedder {
  131. if priority && ng.priorityShedder != nil {
  132. return ng.priorityShedder
  133. }
  134. return ng.shedder
  135. }
  136. // notFoundHandler returns a middleware that handles 404 not found requests.
  137. func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
  138. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  139. chain := alice.New(
  140. handler.TracingHandler(ng.conf.Name, ""),
  141. ng.getLogHandler(),
  142. )
  143. var h http.Handler
  144. if next != nil {
  145. h = chain.Then(next)
  146. } else {
  147. h = chain.Then(http.NotFoundHandler())
  148. }
  149. cw := response.NewHeaderOnceResponseWriter(w)
  150. h.ServeHTTP(cw, r)
  151. cw.WriteHeader(http.StatusNotFound)
  152. })
  153. }
  154. func (ng *engine) print() {
  155. var routes []string
  156. for _, fr := range ng.routes {
  157. for _, route := range fr.routes {
  158. routes = append(routes, fmt.Sprintf("%s %s", route.Method, route.Path))
  159. }
  160. }
  161. sort.Strings(routes)
  162. for _, route := range routes {
  163. fmt.Println(route)
  164. }
  165. }
  166. func (ng *engine) setTlsConfig(cfg *tls.Config) {
  167. ng.tlsConfig = cfg
  168. }
  169. func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
  170. ng.unauthorizedCallback = callback
  171. }
  172. func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
  173. ng.unsignedCallback = callback
  174. }
  175. func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
  176. if !signature.enabled {
  177. return func(chain alice.Chain) alice.Chain {
  178. return chain
  179. }, nil
  180. }
  181. if len(signature.PrivateKeys) == 0 {
  182. if signature.Strict {
  183. return nil, ErrSignatureConfig
  184. }
  185. return func(chain alice.Chain) alice.Chain {
  186. return chain
  187. }, nil
  188. }
  189. decrypters := make(map[string]codec.RsaDecrypter)
  190. for _, key := range signature.PrivateKeys {
  191. fingerprint := key.Fingerprint
  192. file := key.KeyFile
  193. decrypter, err := codec.NewRsaDecrypter(file)
  194. if err != nil {
  195. return nil, err
  196. }
  197. decrypters[fingerprint] = decrypter
  198. }
  199. return func(chain alice.Chain) alice.Chain {
  200. if ng.unsignedCallback != nil {
  201. return chain.Append(handler.ContentSecurityHandler(
  202. decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
  203. }
  204. return chain.Append(handler.ContentSecurityHandler(
  205. decrypters, signature.Expiry, signature.Strict))
  206. }, nil
  207. }
  208. func (ng *engine) start(router httpx.Router) error {
  209. if err := ng.bindRoutes(router); err != nil {
  210. return err
  211. }
  212. if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
  213. return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, ng.withTimeout())
  214. }
  215. return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
  216. ng.conf.KeyFile, router, func(svr *http.Server) {
  217. if ng.tlsConfig != nil {
  218. svr.TLSConfig = ng.tlsConfig
  219. }
  220. }, ng.withTimeout())
  221. }
  222. func (ng *engine) use(middleware Middleware) {
  223. ng.middlewares = append(ng.middlewares, middleware)
  224. }
  225. func (ng *engine) withTimeout() internal.StartOption {
  226. return func(svr *http.Server) {
  227. timeout := ng.conf.Timeout
  228. if timeout > 0 {
  229. // factor 0.8, to avoid clients send longer content-length than the actual content,
  230. // without this timeout setting, the server will time out and respond 503 Service Unavailable,
  231. // which triggers the circuit breaker.
  232. svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
  233. // factor 0.9, to avoid clients not reading the response
  234. // without this timeout setting, the server will time out and respond 503 Service Unavailable,
  235. // which triggers the circuit breaker.
  236. svr.WriteTimeout = 9 * time.Duration(timeout) * time.Millisecond / 10
  237. }
  238. }
  239. }
  240. func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
  241. return func(next http.Handler) http.Handler {
  242. return ware(next.ServeHTTP)
  243. }
  244. }