engine.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "sort"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/codec"
  10. "github.com/zeromicro/go-zero/core/load"
  11. "github.com/zeromicro/go-zero/core/stat"
  12. "github.com/zeromicro/go-zero/rest/chain"
  13. "github.com/zeromicro/go-zero/rest/handler"
  14. "github.com/zeromicro/go-zero/rest/httpx"
  15. "github.com/zeromicro/go-zero/rest/internal"
  16. "github.com/zeromicro/go-zero/rest/internal/response"
  17. )
  const (
  	// topCpuUsage is the CPU usage value that stands for 100%: usage is
  	// expressed on a 0-1000 (1000m) scale.
  	topCpuUsage = 1000
  )

  // ErrSignatureConfig is returned by signatureVerifier when strict signature
  // verification is enabled but no private keys are configured.
  var ErrSignatureConfig = errors.New("bad config for Signature")
  22. type engine struct {
  23. conf RestConf
  24. routes []featuredRoutes
  25. // timeout is the max timeout of all routes
  26. timeout time.Duration
  27. unauthorizedCallback handler.UnauthorizedCallback
  28. unsignedCallback handler.UnsignedCallback
  29. chain chain.Chain
  30. middlewares []Middleware
  31. shedder load.Shedder
  32. priorityShedder load.Shedder
  33. tlsConfig *tls.Config
  34. }
  35. func newEngine(c RestConf) *engine {
  36. svr := &engine{
  37. conf: c,
  38. timeout: time.Duration(c.Timeout) * time.Millisecond,
  39. }
  40. if c.CpuThreshold > 0 {
  41. svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
  42. svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
  43. (c.CpuThreshold + topCpuUsage) >> 1))
  44. }
  45. return svr
  46. }
  47. func (ng *engine) addRoutes(r featuredRoutes) {
  48. ng.routes = append(ng.routes, r)
  49. // need to guarantee the timeout is the max of all routes
  50. // otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
  51. if r.timeout > ng.timeout {
  52. ng.timeout = r.timeout
  53. }
  54. }
  55. func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
  56. verifier func(chain.Chain) chain.Chain) chain.Chain {
  57. if fr.jwt.enabled {
  58. if len(fr.jwt.prevSecret) == 0 {
  59. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  60. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  61. } else {
  62. chn = chn.Append(handler.Authorize(fr.jwt.secret,
  63. handler.WithPrevSecret(fr.jwt.prevSecret),
  64. handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
  65. }
  66. }
  67. return verifier(chn)
  68. }
  69. func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
  70. verifier, err := ng.signatureVerifier(fr.signature)
  71. if err != nil {
  72. return err
  73. }
  74. for _, route := range fr.routes {
  75. if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
  76. return err
  77. }
  78. }
  79. return nil
  80. }
  81. func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
  82. route Route, verifier func(chain.Chain) chain.Chain) error {
  83. chn := ng.chain
  84. if chn == nil {
  85. chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
  86. }
  87. chn = ng.appendAuthHandler(fr, chn, verifier)
  88. for _, middleware := range ng.middlewares {
  89. chn = chn.Append(convertMiddleware(middleware))
  90. }
  91. handle := chn.ThenFunc(route.Handler)
  92. return router.Handle(route.Method, route.Path, handle)
  93. }
  94. func (ng *engine) bindRoutes(router httpx.Router) error {
  95. metrics := ng.createMetrics()
  96. for _, fr := range ng.routes {
  97. if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
  98. return err
  99. }
  100. }
  101. return nil
  102. }
  103. func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
  104. metrics *stat.Metrics) chain.Chain {
  105. chn := chain.New()
  106. if ng.conf.Middlewares.Trace {
  107. chn = chn.Append(handler.TraceHandler(ng.conf.Name,
  108. route.Path,
  109. handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)))
  110. }
  111. if ng.conf.Middlewares.Log {
  112. chn = chn.Append(ng.getLogHandler())
  113. }
  114. if ng.conf.Middlewares.Prometheus {
  115. chn = chn.Append(handler.PrometheusHandler(route.Path, route.Method))
  116. }
  117. if ng.conf.Middlewares.MaxConns {
  118. chn = chn.Append(handler.MaxConnsHandler(ng.conf.MaxConns))
  119. }
  120. if ng.conf.Middlewares.Breaker {
  121. chn = chn.Append(handler.BreakerHandler(route.Method, route.Path, metrics))
  122. }
  123. if ng.conf.Middlewares.Shedding {
  124. chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics))
  125. }
  126. if ng.conf.Middlewares.Timeout {
  127. chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)))
  128. }
  129. if ng.conf.Middlewares.Recover {
  130. chn = chn.Append(handler.RecoverHandler)
  131. }
  132. if ng.conf.Middlewares.Metrics {
  133. chn = chn.Append(handler.MetricHandler(metrics))
  134. }
  135. if ng.conf.Middlewares.MaxBytes {
  136. chn = chn.Append(handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)))
  137. }
  138. if ng.conf.Middlewares.Gunzip {
  139. chn = chn.Append(handler.GunzipHandler)
  140. }
  141. return chn
  142. }
  143. func (ng *engine) checkedMaxBytes(bytes int64) int64 {
  144. if bytes > 0 {
  145. return bytes
  146. }
  147. return ng.conf.MaxBytes
  148. }
  149. func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
  150. if timeout > 0 {
  151. return timeout
  152. }
  153. return time.Duration(ng.conf.Timeout) * time.Millisecond
  154. }
  155. func (ng *engine) createMetrics() *stat.Metrics {
  156. var metrics *stat.Metrics
  157. if len(ng.conf.Name) > 0 {
  158. metrics = stat.NewMetrics(ng.conf.Name)
  159. } else {
  160. metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
  161. }
  162. return metrics
  163. }
  164. func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
  165. if ng.conf.Verbose {
  166. return handler.DetailedLogHandler
  167. }
  168. return handler.LogHandler
  169. }
  170. func (ng *engine) getShedder(priority bool) load.Shedder {
  171. if priority && ng.priorityShedder != nil {
  172. return ng.priorityShedder
  173. }
  174. return ng.shedder
  175. }
  176. // notFoundHandler returns a middleware that handles 404 not found requests.
  177. func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
  178. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  179. chn := chain.New(
  180. handler.TraceHandler(ng.conf.Name,
  181. "",
  182. handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)),
  183. ng.getLogHandler(),
  184. )
  185. var h http.Handler
  186. if next != nil {
  187. h = chn.Then(next)
  188. } else {
  189. h = chn.Then(http.NotFoundHandler())
  190. }
  191. cw := response.NewHeaderOnceResponseWriter(w)
  192. h.ServeHTTP(cw, r)
  193. cw.WriteHeader(http.StatusNotFound)
  194. })
  195. }
  196. func (ng *engine) print() {
  197. var routes []string
  198. for _, fr := range ng.routes {
  199. for _, route := range fr.routes {
  200. routes = append(routes, fmt.Sprintf("%s %s", route.Method, route.Path))
  201. }
  202. }
  203. sort.Strings(routes)
  204. fmt.Println("Routes:")
  205. for _, route := range routes {
  206. fmt.Printf(" %s\n", route)
  207. }
  208. }
  209. func (ng *engine) setTlsConfig(cfg *tls.Config) {
  210. ng.tlsConfig = cfg
  211. }
  212. func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
  213. ng.unauthorizedCallback = callback
  214. }
  215. func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
  216. ng.unsignedCallback = callback
  217. }
  218. func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
  219. if !signature.enabled {
  220. return func(chn chain.Chain) chain.Chain {
  221. return chn
  222. }, nil
  223. }
  224. if len(signature.PrivateKeys) == 0 {
  225. if signature.Strict {
  226. return nil, ErrSignatureConfig
  227. }
  228. return func(chn chain.Chain) chain.Chain {
  229. return chn
  230. }, nil
  231. }
  232. decrypters := make(map[string]codec.RsaDecrypter)
  233. for _, key := range signature.PrivateKeys {
  234. fingerprint := key.Fingerprint
  235. file := key.KeyFile
  236. decrypter, err := codec.NewRsaDecrypter(file)
  237. if err != nil {
  238. return nil, err
  239. }
  240. decrypters[fingerprint] = decrypter
  241. }
  242. return func(chn chain.Chain) chain.Chain {
  243. if ng.unsignedCallback == nil {
  244. return chn.Append(handler.LimitContentSecurityHandler(ng.conf.MaxBytes,
  245. decrypters, signature.Expiry, signature.Strict))
  246. }
  247. return chn.Append(handler.LimitContentSecurityHandler(ng.conf.MaxBytes,
  248. decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
  249. }, nil
  250. }
  251. func (ng *engine) start(ch chan *http.Server, router httpx.Router, opts ...StartOption) error {
  252. if err := ng.bindRoutes(router); err != nil {
  253. return err
  254. }
  255. // make sure user defined options overwrite default options
  256. opts = append([]StartOption{ng.withTimeout()}, opts...)
  257. if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
  258. return internal.StartHttp(ch, ng.conf.Host, ng.conf.Port, router, opts...)
  259. }
  260. // make sure user defined options overwrite default options
  261. opts = append([]StartOption{
  262. func(svr *http.Server) {
  263. if ng.tlsConfig != nil {
  264. svr.TLSConfig = ng.tlsConfig
  265. }
  266. },
  267. }, opts...)
  268. return internal.StartHttps(ch, ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
  269. ng.conf.KeyFile, router, opts...)
  270. }
  271. func (ng *engine) use(middleware Middleware) {
  272. ng.middlewares = append(ng.middlewares, middleware)
  273. }
  274. func (ng *engine) withTimeout() internal.StartOption {
  275. return func(svr *http.Server) {
  276. timeout := ng.timeout
  277. if timeout > 0 {
  278. // factor 0.8, to avoid clients send longer content-length than the actual content,
  279. // without this timeout setting, the server will time out and respond 503 Service Unavailable,
  280. // which triggers the circuit breaker.
  281. svr.ReadTimeout = 4 * timeout / 5
  282. // factor 1.1, to avoid servers don't have enough time to write responses.
  283. // setting the factor less than 1.0 may lead clients not receiving the responses.
  284. svr.WriteTimeout = 11 * timeout / 10
  285. }
  286. }
  287. }
  288. func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
  289. return func(next http.Handler) http.Handler {
  290. return ware(next.ServeHTTP)
  291. }
  292. }