engine.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "sort"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/codec"
  10. "github.com/zeromicro/go-zero/core/load"
  11. "github.com/zeromicro/go-zero/core/stat"
  12. "github.com/zeromicro/go-zero/rest/chain"
  13. "github.com/zeromicro/go-zero/rest/handler"
  14. "github.com/zeromicro/go-zero/rest/httpx"
  15. "github.com/zeromicro/go-zero/rest/internal"
  16. "github.com/zeromicro/go-zero/rest/internal/response"
  17. )
  // topCpuUsage is the CPU usage value that stands for 100%; usage is measured
  // in thousandths, so 1000m represents a fully used CPU.
  const topCpuUsage = 1000

  // ErrSignatureConfig is an error that indicates a bad signature configuration,
  // e.g. strict signature checking enabled without any private keys
  // (see signatureVerifier).
  var ErrSignatureConfig = errors.New("bad config for Signature")
  22. type engine struct {
  23. conf RestConf
  24. routes []featuredRoutes
  25. unauthorizedCallback handler.UnauthorizedCallback
  26. unsignedCallback handler.UnsignedCallback
  27. chain chain.Chain
  28. middlewares []Middleware
  29. shedder load.Shedder
  30. priorityShedder load.Shedder
  31. tlsConfig *tls.Config
  32. }
  33. func newEngine(c RestConf) *engine {
  34. svr := &engine{
  35. conf: c,
  36. }
  37. if c.CpuThreshold > 0 {
  38. svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
  39. svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
  40. (c.CpuThreshold + topCpuUsage) >> 1))
  41. }
  42. return svr
  43. }
  44. func (ng *engine) addRoutes(r featuredRoutes) {
  45. ng.routes = append(ng.routes, r)
  46. }
  47. func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
  48. verifier func(chain.Chain) chain.Chain) chain.Chain {
  49. if fr.jwt.enabled {
  50. if len(fr.jwt.prevSecret) == 0 {
  51. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  52. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  53. } else {
  54. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  55. handler.WithPrevSecret(fr.jwt.prevSecret),
  56. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  57. }
  58. }
  59. return verifier(chn)
  60. }
  61. func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
  62. verifier, err := ng.signatureVerifier(fr.signature)
  63. if err != nil {
  64. return err
  65. }
  66. for _, route := range fr.routes {
  67. if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
  68. return err
  69. }
  70. }
  71. return nil
  72. }
  73. func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
  74. route Route, verifier func(chain.Chain) chain.Chain) error {
  75. chn := ng.chain
  76. if chn == nil {
  77. chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
  78. }
  79. chn = ng.appendAuthHandler(fr, chn, verifier)
  80. for _, middleware := range ng.middlewares {
  81. chn = chn.Append(convertMiddleware(middleware))
  82. }
  83. handle := chn.ThenFunc(route.Handler)
  84. return router.Handle(route.Method, route.Path, handle)
  85. }
  86. func (ng *engine) bindRoutes(router httpx.Router) error {
  87. metrics := ng.createMetrics()
  88. for _, fr := range ng.routes {
  89. if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
  90. return err
  91. }
  92. }
  93. return nil
  94. }
  95. func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
  96. metrics *stat.Metrics) chain.Chain {
  97. chn := chain.New()
  98. if ng.conf.Middlewares.Trace {
  99. chn = chn.Append(handler.TraceHandler(ng.conf.Name,
  100. route.Path,
  101. handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)))
  102. }
  103. if ng.conf.Middlewares.Log {
  104. chn = chn.Append(ng.getLogHandler())
  105. }
  106. if ng.conf.Middlewares.Prometheus {
  107. chn = chn.Append(handler.PrometheusHandler(route.Path))
  108. }
  109. if ng.conf.Middlewares.MaxConns {
  110. chn = chn.Append(handler.MaxConnsHandler(ng.conf.MaxConns))
  111. }
  112. if ng.conf.Middlewares.Breaker {
  113. chn = chn.Append(handler.BreakerHandler(route.Method, route.Path, metrics))
  114. }
  115. if ng.conf.Middlewares.Shedding {
  116. chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics))
  117. }
  118. if ng.conf.Middlewares.Timeout {
  119. chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)))
  120. }
  121. if ng.conf.Middlewares.Recover {
  122. chn = chn.Append(handler.RecoverHandler)
  123. }
  124. if ng.conf.Middlewares.Metrics {
  125. chn = chn.Append(handler.MetricHandler(metrics))
  126. }
  127. if ng.conf.Middlewares.MaxBytes {
  128. chn = chn.Append(handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)))
  129. }
  130. if ng.conf.Middlewares.Gunzip {
  131. chn = chn.Append(handler.GunzipHandler)
  132. }
  133. return chn
  134. }
  135. func (ng *engine) checkedMaxBytes(bytes int64) int64 {
  136. if bytes > 0 {
  137. return bytes
  138. }
  139. return ng.conf.MaxBytes
  140. }
  141. func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
  142. if timeout > 0 {
  143. return timeout
  144. }
  145. return time.Duration(ng.conf.Timeout) * time.Millisecond
  146. }
  147. func (ng *engine) createMetrics() *stat.Metrics {
  148. var metrics *stat.Metrics
  149. if len(ng.conf.Name) > 0 {
  150. metrics = stat.NewMetrics(ng.conf.Name)
  151. } else {
  152. metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
  153. }
  154. return metrics
  155. }
  156. func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
  157. if ng.conf.Verbose {
  158. return handler.DetailedLogHandler
  159. }
  160. return handler.LogHandler
  161. }
  162. func (ng *engine) getShedder(priority bool) load.Shedder {
  163. if priority && ng.priorityShedder != nil {
  164. return ng.priorityShedder
  165. }
  166. return ng.shedder
  167. }
  168. // notFoundHandler returns a middleware that handles 404 not found requests.
  169. func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
  170. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  171. chn := chain.New(
  172. handler.TraceHandler(ng.conf.Name,
  173. "",
  174. handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)),
  175. ng.getLogHandler(),
  176. )
  177. var h http.Handler
  178. if next != nil {
  179. h = chn.Then(next)
  180. } else {
  181. h = chn.Then(http.NotFoundHandler())
  182. }
  183. cw := response.NewHeaderOnceResponseWriter(w)
  184. h.ServeHTTP(cw, r)
  185. cw.WriteHeader(http.StatusNotFound)
  186. })
  187. }
  188. func (ng *engine) print() {
  189. var routes []string
  190. for _, fr := range ng.routes {
  191. for _, route := range fr.routes {
  192. routes = append(routes, fmt.Sprintf("%s %s", route.Method, route.Path))
  193. }
  194. }
  195. sort.Strings(routes)
  196. fmt.Println("Routes:")
  197. for _, route := range routes {
  198. fmt.Printf(" %s\n", route)
  199. }
  200. }
  201. func (ng *engine) setTlsConfig(cfg *tls.Config) {
  202. ng.tlsConfig = cfg
  203. }
  204. func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
  205. ng.unauthorizedCallback = callback
  206. }
  207. func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
  208. ng.unsignedCallback = callback
  209. }
  210. func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
  211. if !signature.enabled {
  212. return func(chn chain.Chain) chain.Chain {
  213. return chn
  214. }, nil
  215. }
  216. if len(signature.PrivateKeys) == 0 {
  217. if signature.Strict {
  218. return nil, ErrSignatureConfig
  219. }
  220. return func(chn chain.Chain) chain.Chain {
  221. return chn
  222. }, nil
  223. }
  224. decrypters := make(map[string]codec.RsaDecrypter)
  225. for _, key := range signature.PrivateKeys {
  226. fingerprint := key.Fingerprint
  227. file := key.KeyFile
  228. decrypter, err := codec.NewRsaDecrypter(file)
  229. if err != nil {
  230. return nil, err
  231. }
  232. decrypters[fingerprint] = decrypter
  233. }
  234. return func(chn chain.Chain) chain.Chain {
  235. if ng.unsignedCallback != nil {
  236. return chn.Append(handler.ContentSecurityHandler(
  237. decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
  238. }
  239. return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
  240. }, nil
  241. }
  242. func (ng *engine) start(router httpx.Router) error {
  243. if err := ng.bindRoutes(router); err != nil {
  244. return err
  245. }
  246. if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
  247. return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, ng.withTimeout())
  248. }
  249. return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
  250. ng.conf.KeyFile, router, func(svr *http.Server) {
  251. if ng.tlsConfig != nil {
  252. svr.TLSConfig = ng.tlsConfig
  253. }
  254. }, ng.withTimeout())
  255. }
  256. func (ng *engine) use(middleware Middleware) {
  257. ng.middlewares = append(ng.middlewares, middleware)
  258. }
  259. func (ng *engine) withTimeout() internal.StartOption {
  260. return func(svr *http.Server) {
  261. timeout := ng.conf.Timeout
  262. if timeout > 0 {
  263. // factor 0.8, to avoid clients send longer content-length than the actual content,
  264. // without this timeout setting, the server will time out and respond 503 Service Unavailable,
  265. // which triggers the circuit breaker.
  266. svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
  267. // factor 1.1, to avoid servers don't have enough time to write responses.
  268. // setting the factor less than 1.0 may lead clients not receiving the responses.
  269. svr.WriteTimeout = 11 * time.Duration(timeout) * time.Millisecond / 10
  270. }
  271. }
  272. }
  273. func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
  274. return func(next http.Handler) http.Handler {
  275. return ware(next.ServeHTTP)
  276. }
  277. }