engine_test.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. package rest
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync/atomic"
  8. "testing"
  9. "time"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/zeromicro/go-zero/core/conf"
  12. "github.com/zeromicro/go-zero/core/logx"
  13. )
  14. func TestNewEngine(t *testing.T) {
  15. yamls := []string{
  16. `Name: foo
  17. Port: 54321
  18. Middlewares:
  19. Log: false
  20. `,
  21. `Name: foo
  22. Port: 54321
  23. CpuThreshold: 500
  24. Middlewares:
  25. Log: false
  26. `,
  27. `Name: foo
  28. Port: 54321
  29. CpuThreshold: 500
  30. Verbose: true
  31. `,
  32. }
  33. routes := []featuredRoutes{
  34. {
  35. jwt: jwtSetting{},
  36. signature: signatureSetting{},
  37. routes: []Route{{
  38. Method: http.MethodGet,
  39. Path: "/",
  40. Handler: func(w http.ResponseWriter, r *http.Request) {},
  41. }},
  42. timeout: time.Minute,
  43. },
  44. {
  45. priority: true,
  46. jwt: jwtSetting{},
  47. signature: signatureSetting{},
  48. routes: []Route{{
  49. Method: http.MethodGet,
  50. Path: "/",
  51. Handler: func(w http.ResponseWriter, r *http.Request) {},
  52. }},
  53. timeout: time.Second,
  54. },
  55. {
  56. priority: true,
  57. jwt: jwtSetting{
  58. enabled: true,
  59. },
  60. signature: signatureSetting{},
  61. routes: []Route{{
  62. Method: http.MethodGet,
  63. Path: "/",
  64. Handler: func(w http.ResponseWriter, r *http.Request) {},
  65. }},
  66. },
  67. {
  68. priority: true,
  69. jwt: jwtSetting{
  70. enabled: true,
  71. prevSecret: "thesecret",
  72. },
  73. signature: signatureSetting{},
  74. routes: []Route{{
  75. Method: http.MethodGet,
  76. Path: "/",
  77. Handler: func(w http.ResponseWriter, r *http.Request) {},
  78. }},
  79. },
  80. {
  81. priority: true,
  82. jwt: jwtSetting{
  83. enabled: true,
  84. },
  85. signature: signatureSetting{},
  86. routes: []Route{{
  87. Method: http.MethodGet,
  88. Path: "/",
  89. Handler: func(w http.ResponseWriter, r *http.Request) {},
  90. }},
  91. },
  92. {
  93. priority: true,
  94. jwt: jwtSetting{
  95. enabled: true,
  96. },
  97. signature: signatureSetting{
  98. enabled: true,
  99. },
  100. routes: []Route{{
  101. Method: http.MethodGet,
  102. Path: "/",
  103. Handler: func(w http.ResponseWriter, r *http.Request) {},
  104. }},
  105. },
  106. {
  107. priority: true,
  108. jwt: jwtSetting{
  109. enabled: true,
  110. },
  111. signature: signatureSetting{
  112. enabled: true,
  113. SignatureConf: SignatureConf{
  114. Strict: true,
  115. },
  116. },
  117. routes: []Route{{
  118. Method: http.MethodGet,
  119. Path: "/",
  120. Handler: func(w http.ResponseWriter, r *http.Request) {},
  121. }},
  122. },
  123. {
  124. priority: true,
  125. jwt: jwtSetting{
  126. enabled: true,
  127. },
  128. signature: signatureSetting{
  129. enabled: true,
  130. SignatureConf: SignatureConf{
  131. Strict: true,
  132. PrivateKeys: []PrivateKeyConf{
  133. {
  134. Fingerprint: "a",
  135. KeyFile: "b",
  136. },
  137. },
  138. },
  139. },
  140. routes: []Route{{
  141. Method: http.MethodGet,
  142. Path: "/",
  143. Handler: func(w http.ResponseWriter, r *http.Request) {},
  144. }},
  145. },
  146. }
  147. for _, yaml := range yamls {
  148. for _, route := range routes {
  149. var cnf RestConf
  150. assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
  151. ng := newEngine(cnf)
  152. ng.addRoutes(route)
  153. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  154. return func(w http.ResponseWriter, r *http.Request) {
  155. next.ServeHTTP(w, r)
  156. }
  157. })
  158. assert.NotNil(t, ng.start(mockedRouter{}))
  159. timeout := time.Second * 3
  160. if route.timeout > timeout {
  161. timeout = route.timeout
  162. }
  163. assert.Equal(t, timeout, ng.timeout)
  164. }
  165. }
  166. }
  167. func TestEngine_checkedTimeout(t *testing.T) {
  168. tests := []struct {
  169. name string
  170. timeout time.Duration
  171. expect time.Duration
  172. }{
  173. {
  174. name: "not set",
  175. expect: time.Second,
  176. },
  177. {
  178. name: "less",
  179. timeout: time.Millisecond * 500,
  180. expect: time.Millisecond * 500,
  181. },
  182. {
  183. name: "equal",
  184. timeout: time.Second,
  185. expect: time.Second,
  186. },
  187. {
  188. name: "more",
  189. timeout: time.Millisecond * 1500,
  190. expect: time.Millisecond * 1500,
  191. },
  192. }
  193. ng := newEngine(RestConf{
  194. Timeout: 1000,
  195. })
  196. for _, test := range tests {
  197. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  198. }
  199. }
  200. func TestEngine_checkedMaxBytes(t *testing.T) {
  201. tests := []struct {
  202. name string
  203. maxBytes int64
  204. expect int64
  205. }{
  206. {
  207. name: "not set",
  208. expect: 1000,
  209. },
  210. {
  211. name: "less",
  212. maxBytes: 500,
  213. expect: 500,
  214. },
  215. {
  216. name: "equal",
  217. maxBytes: 1000,
  218. expect: 1000,
  219. },
  220. {
  221. name: "more",
  222. maxBytes: 1500,
  223. expect: 1500,
  224. },
  225. }
  226. ng := newEngine(RestConf{
  227. MaxBytes: 1000,
  228. })
  229. for _, test := range tests {
  230. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  231. }
  232. }
  233. func TestEngine_notFoundHandler(t *testing.T) {
  234. logx.Disable()
  235. ng := newEngine(RestConf{})
  236. ts := httptest.NewServer(ng.notFoundHandler(nil))
  237. defer ts.Close()
  238. client := ts.Client()
  239. err := func(_ context.Context) error {
  240. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  241. assert.Nil(t, err)
  242. res, err := client.Do(req)
  243. assert.Nil(t, err)
  244. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  245. return res.Body.Close()
  246. }(context.Background())
  247. assert.Nil(t, err)
  248. }
  249. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  250. logx.Disable()
  251. ng := newEngine(RestConf{})
  252. var called int32
  253. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  254. atomic.AddInt32(&called, 1)
  255. })))
  256. defer ts.Close()
  257. client := ts.Client()
  258. err := func(_ context.Context) error {
  259. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  260. assert.Nil(t, err)
  261. res, err := client.Do(req)
  262. assert.Nil(t, err)
  263. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  264. return res.Body.Close()
  265. }(context.Background())
  266. assert.Nil(t, err)
  267. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  268. }
  269. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  270. logx.Disable()
  271. ng := newEngine(RestConf{})
  272. var called int32
  273. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  274. atomic.AddInt32(&called, 1)
  275. w.WriteHeader(http.StatusExpectationFailed)
  276. })))
  277. defer ts.Close()
  278. client := ts.Client()
  279. err := func(_ context.Context) error {
  280. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  281. assert.Nil(t, err)
  282. res, err := client.Do(req)
  283. assert.Nil(t, err)
  284. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  285. return res.Body.Close()
  286. }(context.Background())
  287. assert.Nil(t, err)
  288. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  289. }
  290. func TestEngine_withTimeout(t *testing.T) {
  291. logx.Disable()
  292. tests := []struct {
  293. name string
  294. timeout int64
  295. }{
  296. {
  297. name: "not set",
  298. },
  299. {
  300. name: "set",
  301. timeout: 1000,
  302. },
  303. }
  304. for _, test := range tests {
  305. test := test
  306. t.Run(test.name, func(t *testing.T) {
  307. ng := newEngine(RestConf{Timeout: test.timeout})
  308. svr := &http.Server{}
  309. ng.withTimeout()(svr)
  310. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
  311. assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
  312. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
  313. assert.Equal(t, time.Duration(0), svr.IdleTimeout)
  314. })
  315. }
  316. }
  // mockedRouter is a do-nothing router stub for the engine tests. It serves no
// requests and ignores not-found/not-allowed handler registration. Every route
// registration is rejected with the error "foo".
type mockedRouter struct{}

// ServeHTTP intentionally writes nothing.
func (mockedRouter) ServeHTTP(http.ResponseWriter, *http.Request) {}

// Handle rejects every registration with the error "foo".
func (mockedRouter) Handle(_, _ string, _ http.Handler) error {
	return errors.New("foo")
}

// SetNotFoundHandler discards the handler.
func (mockedRouter) SetNotFoundHandler(http.Handler) {}

// SetNotAllowedHandler discards the handler.
func (mockedRouter) SetNotAllowedHandler(http.Handler) {}