// Source listing: engine_test.go (7.4 KB)
  1. package rest
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "sync/atomic"
  9. "testing"
  10. "time"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/zeromicro/go-zero/core/conf"
  13. "github.com/zeromicro/go-zero/core/logx"
  14. )
  15. func TestNewEngine(t *testing.T) {
  16. yamls := []string{
  17. `Name: foo
  18. Host: localhost
  19. Port: 0
  20. Middlewares:
  21. Log: false
  22. `,
  23. `Name: foo
  24. Host: localhost
  25. Port: 0
  26. CpuThreshold: 500
  27. Middlewares:
  28. Log: false
  29. `,
  30. `Name: foo
  31. Host: localhost
  32. Port: 0
  33. CpuThreshold: 500
  34. Verbose: true
  35. `,
  36. }
  37. routes := []featuredRoutes{
  38. {
  39. jwt: jwtSetting{},
  40. signature: signatureSetting{},
  41. routes: []Route{{
  42. Method: http.MethodGet,
  43. Path: "/",
  44. Handler: func(w http.ResponseWriter, r *http.Request) {},
  45. }},
  46. timeout: time.Minute,
  47. },
  48. {
  49. priority: true,
  50. jwt: jwtSetting{},
  51. signature: signatureSetting{},
  52. routes: []Route{{
  53. Method: http.MethodGet,
  54. Path: "/",
  55. Handler: func(w http.ResponseWriter, r *http.Request) {},
  56. }},
  57. timeout: time.Second,
  58. },
  59. {
  60. priority: true,
  61. jwt: jwtSetting{
  62. enabled: true,
  63. },
  64. signature: signatureSetting{},
  65. routes: []Route{{
  66. Method: http.MethodGet,
  67. Path: "/",
  68. Handler: func(w http.ResponseWriter, r *http.Request) {},
  69. }},
  70. },
  71. {
  72. priority: true,
  73. jwt: jwtSetting{
  74. enabled: true,
  75. prevSecret: "thesecret",
  76. },
  77. signature: signatureSetting{},
  78. routes: []Route{{
  79. Method: http.MethodGet,
  80. Path: "/",
  81. Handler: func(w http.ResponseWriter, r *http.Request) {},
  82. }},
  83. },
  84. {
  85. priority: true,
  86. jwt: jwtSetting{
  87. enabled: true,
  88. },
  89. signature: signatureSetting{},
  90. routes: []Route{{
  91. Method: http.MethodGet,
  92. Path: "/",
  93. Handler: func(w http.ResponseWriter, r *http.Request) {},
  94. }},
  95. },
  96. {
  97. priority: true,
  98. jwt: jwtSetting{
  99. enabled: true,
  100. },
  101. signature: signatureSetting{
  102. enabled: true,
  103. },
  104. routes: []Route{{
  105. Method: http.MethodGet,
  106. Path: "/",
  107. Handler: func(w http.ResponseWriter, r *http.Request) {},
  108. }},
  109. },
  110. {
  111. priority: true,
  112. jwt: jwtSetting{
  113. enabled: true,
  114. },
  115. signature: signatureSetting{
  116. enabled: true,
  117. SignatureConf: SignatureConf{
  118. Strict: true,
  119. },
  120. },
  121. routes: []Route{{
  122. Method: http.MethodGet,
  123. Path: "/",
  124. Handler: func(w http.ResponseWriter, r *http.Request) {},
  125. }},
  126. },
  127. {
  128. priority: true,
  129. jwt: jwtSetting{
  130. enabled: true,
  131. },
  132. signature: signatureSetting{
  133. enabled: true,
  134. SignatureConf: SignatureConf{
  135. Strict: true,
  136. PrivateKeys: []PrivateKeyConf{
  137. {
  138. Fingerprint: "a",
  139. KeyFile: "b",
  140. },
  141. },
  142. },
  143. },
  144. routes: []Route{{
  145. Method: http.MethodGet,
  146. Path: "/",
  147. Handler: func(w http.ResponseWriter, r *http.Request) {},
  148. }},
  149. },
  150. }
  151. for _, yaml := range yamls {
  152. yaml := yaml
  153. for _, route := range routes {
  154. route := route
  155. t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
  156. var cnf RestConf
  157. assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
  158. ng := newEngine(cnf)
  159. ng.addRoutes(route)
  160. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  161. return func(w http.ResponseWriter, r *http.Request) {
  162. next.ServeHTTP(w, r)
  163. }
  164. })
  165. assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
  166. }))
  167. timeout := time.Second * 3
  168. if route.timeout > timeout {
  169. timeout = route.timeout
  170. }
  171. assert.Equal(t, timeout, ng.timeout)
  172. })
  173. }
  174. }
  175. }
  176. func TestEngine_checkedTimeout(t *testing.T) {
  177. tests := []struct {
  178. name string
  179. timeout time.Duration
  180. expect time.Duration
  181. }{
  182. {
  183. name: "not set",
  184. expect: time.Second,
  185. },
  186. {
  187. name: "less",
  188. timeout: time.Millisecond * 500,
  189. expect: time.Millisecond * 500,
  190. },
  191. {
  192. name: "equal",
  193. timeout: time.Second,
  194. expect: time.Second,
  195. },
  196. {
  197. name: "more",
  198. timeout: time.Millisecond * 1500,
  199. expect: time.Millisecond * 1500,
  200. },
  201. }
  202. ng := newEngine(RestConf{
  203. Timeout: 1000,
  204. })
  205. for _, test := range tests {
  206. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  207. }
  208. }
  209. func TestEngine_checkedMaxBytes(t *testing.T) {
  210. tests := []struct {
  211. name string
  212. maxBytes int64
  213. expect int64
  214. }{
  215. {
  216. name: "not set",
  217. expect: 1000,
  218. },
  219. {
  220. name: "less",
  221. maxBytes: 500,
  222. expect: 500,
  223. },
  224. {
  225. name: "equal",
  226. maxBytes: 1000,
  227. expect: 1000,
  228. },
  229. {
  230. name: "more",
  231. maxBytes: 1500,
  232. expect: 1500,
  233. },
  234. }
  235. ng := newEngine(RestConf{
  236. MaxBytes: 1000,
  237. })
  238. for _, test := range tests {
  239. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  240. }
  241. }
  242. func TestEngine_notFoundHandler(t *testing.T) {
  243. logx.Disable()
  244. ng := newEngine(RestConf{})
  245. ts := httptest.NewServer(ng.notFoundHandler(nil))
  246. defer ts.Close()
  247. client := ts.Client()
  248. err := func(_ context.Context) error {
  249. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  250. assert.Nil(t, err)
  251. res, err := client.Do(req)
  252. assert.Nil(t, err)
  253. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  254. return res.Body.Close()
  255. }(context.Background())
  256. assert.Nil(t, err)
  257. }
  258. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  259. logx.Disable()
  260. ng := newEngine(RestConf{})
  261. var called int32
  262. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  263. atomic.AddInt32(&called, 1)
  264. })))
  265. defer ts.Close()
  266. client := ts.Client()
  267. err := func(_ context.Context) error {
  268. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  269. assert.Nil(t, err)
  270. res, err := client.Do(req)
  271. assert.Nil(t, err)
  272. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  273. return res.Body.Close()
  274. }(context.Background())
  275. assert.Nil(t, err)
  276. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  277. }
  278. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  279. logx.Disable()
  280. ng := newEngine(RestConf{})
  281. var called int32
  282. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  283. atomic.AddInt32(&called, 1)
  284. w.WriteHeader(http.StatusExpectationFailed)
  285. })))
  286. defer ts.Close()
  287. client := ts.Client()
  288. err := func(_ context.Context) error {
  289. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  290. assert.Nil(t, err)
  291. res, err := client.Do(req)
  292. assert.Nil(t, err)
  293. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  294. return res.Body.Close()
  295. }(context.Background())
  296. assert.Nil(t, err)
  297. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  298. }
  299. func TestEngine_withTimeout(t *testing.T) {
  300. logx.Disable()
  301. tests := []struct {
  302. name string
  303. timeout int64
  304. }{
  305. {
  306. name: "not set",
  307. },
  308. {
  309. name: "set",
  310. timeout: 1000,
  311. },
  312. }
  313. for _, test := range tests {
  314. test := test
  315. t.Run(test.name, func(t *testing.T) {
  316. ng := newEngine(RestConf{Timeout: test.timeout})
  317. svr := &http.Server{}
  318. ng.withTimeout()(svr)
  319. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
  320. assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
  321. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
  322. assert.Equal(t, time.Duration(0), svr.IdleTimeout)
  323. })
  324. }
  325. }
  326. type mockedRouter struct {
  327. }
  328. func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
  329. }
  330. func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
  331. return errors.New("foo")
  332. }
  333. func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
  334. }
  335. func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
  336. }