engine.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "sort"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/codec"
  10. "github.com/zeromicro/go-zero/core/load"
  11. "github.com/zeromicro/go-zero/core/stat"
  12. "github.com/zeromicro/go-zero/rest/chain"
  13. "github.com/zeromicro/go-zero/rest/handler"
  14. "github.com/zeromicro/go-zero/rest/httpx"
  15. "github.com/zeromicro/go-zero/rest/internal"
  16. "github.com/zeromicro/go-zero/rest/internal/response"
  17. )
  18. // use 1000m to represent 100%
  19. const topCpuUsage = 1000
  20. // ErrSignatureConfig is an error that indicates bad config for signature.
  21. var ErrSignatureConfig = errors.New("bad config for Signature")
  22. type engine struct {
  23. conf RestConf
  24. routes []featuredRoutes
  25. unauthorizedCallback handler.UnauthorizedCallback
  26. unsignedCallback handler.UnsignedCallback
  27. chain chain.Chain
  28. middlewares []Middleware
  29. shedder load.Shedder
  30. priorityShedder load.Shedder
  31. tlsConfig *tls.Config
  32. }
  33. func newEngine(c RestConf) *engine {
  34. svr := &engine{
  35. conf: c,
  36. }
  37. if c.CpuThreshold > 0 {
  38. svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
  39. svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
  40. (c.CpuThreshold + topCpuUsage) >> 1))
  41. }
  42. return svr
  43. }
  44. func (ng *engine) addRoutes(r featuredRoutes) {
  45. ng.routes = append(ng.routes, r)
  46. }
  47. func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
  48. verifier func(chain.Chain) chain.Chain) chain.Chain {
  49. if fr.jwt.enabled {
  50. if len(fr.jwt.prevSecret) == 0 {
  51. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  52. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  53. } else {
  54. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  55. handler.WithPrevSecret(fr.jwt.prevSecret),
  56. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  57. }
  58. }
  59. return verifier(chn)
  60. }
  61. func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
  62. verifier, err := ng.signatureVerifier(fr.signature)
  63. if err != nil {
  64. return err
  65. }
  66. for _, route := range fr.routes {
  67. if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
  68. return err
  69. }
  70. }
  71. return nil
  72. }
  73. func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
  74. route Route, verifier func(chain.Chain) chain.Chain) error {
  75. chn := ng.chain
  76. if chn == nil {
  77. chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
  78. }
  79. chn = ng.appendAuthHandler(fr, chn, verifier)
  80. for _, middleware := range ng.middlewares {
  81. chn = chn.Append(convertMiddleware(middleware))
  82. }
  83. handle := chn.ThenFunc(route.Handler)
  84. return router.Handle(route.Method, route.Path, handle)
  85. }
  86. func (ng *engine) bindRoutes(router httpx.Router) error {
  87. metrics := ng.createMetrics()
  88. for _, fr := range ng.routes {
  89. if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
  90. return err
  91. }
  92. }
  93. return nil
  94. }
  95. func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
  96. metrics *stat.Metrics) chain.Chain {
  97. chn := chain.New()
  98. if ng.conf.Middlewares.Trace {
  99. chn = chn.Append(handler.TracingHandler(ng.conf.Name, route.Path))
  100. }
  101. if ng.conf.Middlewares.Log {
  102. chn = chn.Append(ng.getLogHandler())
  103. }
  104. if ng.conf.Middlewares.Prometheus {
  105. chn = chn.Append(handler.PrometheusHandler(route.Path))
  106. }
  107. if ng.conf.Middlewares.MaxConns {
  108. chn = chn.Append(handler.MaxConnsHandler(ng.conf.MaxConns))
  109. }
  110. if ng.conf.Middlewares.Breaker {
  111. chn = chn.Append(handler.BreakerHandler(route.Method, route.Path, metrics))
  112. }
  113. if ng.conf.Middlewares.Shedding {
  114. chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics))
  115. }
  116. if ng.conf.Middlewares.Timeout {
  117. chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)))
  118. }
  119. if ng.conf.Middlewares.Recover {
  120. chn = chn.Append(handler.RecoverHandler)
  121. }
  122. if ng.conf.Middlewares.Metrics {
  123. chn = chn.Append(handler.MetricHandler(metrics))
  124. }
  125. if ng.conf.Middlewares.MaxBytes {
  126. chn = chn.Append(handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)))
  127. }
  128. if ng.conf.Middlewares.Gunzip {
  129. chn = chn.Append(handler.GunzipHandler)
  130. }
  131. return chn
  132. }
  133. func (ng *engine) checkedMaxBytes(bytes int64) int64 {
  134. if bytes > 0 {
  135. return bytes
  136. }
  137. return ng.conf.MaxBytes
  138. }
  139. func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
  140. if timeout > 0 {
  141. return timeout
  142. }
  143. return time.Duration(ng.conf.Timeout) * time.Millisecond
  144. }
  145. func (ng *engine) createMetrics() *stat.Metrics {
  146. var metrics *stat.Metrics
  147. if len(ng.conf.Name) > 0 {
  148. metrics = stat.NewMetrics(ng.conf.Name)
  149. } else {
  150. metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
  151. }
  152. return metrics
  153. }
  154. func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
  155. if ng.conf.Verbose {
  156. return handler.DetailedLogHandler
  157. }
  158. return handler.LogHandler
  159. }
  160. func (ng *engine) getShedder(priority bool) load.Shedder {
  161. if priority && ng.priorityShedder != nil {
  162. return ng.priorityShedder
  163. }
  164. return ng.shedder
  165. }
  166. // notFoundHandler returns a middleware that handles 404 not found requests.
  167. func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
  168. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  169. chn := chain.New(
  170. handler.TracingHandler(ng.conf.Name, ""),
  171. ng.getLogHandler(),
  172. )
  173. var h http.Handler
  174. if next != nil {
  175. h = chn.Then(next)
  176. } else {
  177. h = chn.Then(http.NotFoundHandler())
  178. }
  179. cw := response.NewHeaderOnceResponseWriter(w)
  180. h.ServeHTTP(cw, r)
  181. cw.WriteHeader(http.StatusNotFound)
  182. })
  183. }
  184. func (ng *engine) print() {
  185. var routes []string
  186. for _, fr := range ng.routes {
  187. for _, route := range fr.routes {
  188. routes = append(routes, fmt.Sprintf("%s %s", route.Method, route.Path))
  189. }
  190. }
  191. sort.Strings(routes)
  192. fmt.Println("Routes:")
  193. for _, route := range routes {
  194. fmt.Printf(" %s\n", route)
  195. }
  196. }
  197. func (ng *engine) setTlsConfig(cfg *tls.Config) {
  198. ng.tlsConfig = cfg
  199. }
  200. func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
  201. ng.unauthorizedCallback = callback
  202. }
  203. func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
  204. ng.unsignedCallback = callback
  205. }
  206. func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
  207. if !signature.enabled {
  208. return func(chn chain.Chain) chain.Chain {
  209. return chn
  210. }, nil
  211. }
  212. if len(signature.PrivateKeys) == 0 {
  213. if signature.Strict {
  214. return nil, ErrSignatureConfig
  215. }
  216. return func(chn chain.Chain) chain.Chain {
  217. return chn
  218. }, nil
  219. }
  220. decrypters := make(map[string]codec.RsaDecrypter)
  221. for _, key := range signature.PrivateKeys {
  222. fingerprint := key.Fingerprint
  223. file := key.KeyFile
  224. decrypter, err := codec.NewRsaDecrypter(file)
  225. if err != nil {
  226. return nil, err
  227. }
  228. decrypters[fingerprint] = decrypter
  229. }
  230. return func(chn chain.Chain) chain.Chain {
  231. if ng.unsignedCallback != nil {
  232. return chn.Append(handler.ContentSecurityHandler(
  233. decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
  234. }
  235. return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
  236. }, nil
  237. }
  238. func (ng *engine) start(router httpx.Router) error {
  239. if err := ng.bindRoutes(router); err != nil {
  240. return err
  241. }
  242. if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
  243. return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, ng.withTimeout())
  244. }
  245. return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
  246. ng.conf.KeyFile, router, func(svr *http.Server) {
  247. if ng.tlsConfig != nil {
  248. svr.TLSConfig = ng.tlsConfig
  249. }
  250. }, ng.withTimeout())
  251. }
  252. func (ng *engine) use(middleware Middleware) {
  253. ng.middlewares = append(ng.middlewares, middleware)
  254. }
  255. func (ng *engine) withTimeout() internal.StartOption {
  256. return func(svr *http.Server) {
  257. timeout := ng.conf.Timeout
  258. if timeout > 0 {
  259. // factor 0.8, to avoid clients send longer content-length than the actual content,
  260. // without this timeout setting, the server will time out and respond 503 Service Unavailable,
  261. // which triggers the circuit breaker.
  262. svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
  263. // factor 1.1, to avoid servers don't have enough time to write responses.
  264. // setting the factor less than 1.0 may lead clients not receiving the responses.
  265. svr.WriteTimeout = 11 * time.Duration(timeout) * time.Millisecond / 10
  266. }
  267. }
  268. }
  269. func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
  270. return func(next http.Handler) http.Handler {
  271. return ware(next.ServeHTTP)
  272. }
  273. }