engine_test.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. package rest
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync/atomic"
  8. "testing"
  9. "time"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/zeromicro/go-zero/core/conf"
  12. "github.com/zeromicro/go-zero/core/logx"
  13. )
  14. func TestNewEngine(t *testing.T) {
  15. yamls := []string{
  16. `Name: foo
  17. Port: 54321
  18. Middlewares:
  19. Log: false
  20. `,
  21. `Name: foo
  22. Port: 54321
  23. CpuThreshold: 500
  24. Middlewares:
  25. Log: false
  26. `,
  27. `Name: foo
  28. Port: 54321
  29. CpuThreshold: 500
  30. Verbose: true
  31. `,
  32. }
  33. routes := []featuredRoutes{
  34. {
  35. jwt: jwtSetting{},
  36. signature: signatureSetting{},
  37. routes: []Route{{
  38. Method: http.MethodGet,
  39. Path: "/",
  40. Handler: func(w http.ResponseWriter, r *http.Request) {},
  41. }},
  42. },
  43. {
  44. priority: true,
  45. jwt: jwtSetting{},
  46. signature: signatureSetting{},
  47. routes: []Route{{
  48. Method: http.MethodGet,
  49. Path: "/",
  50. Handler: func(w http.ResponseWriter, r *http.Request) {},
  51. }},
  52. },
  53. {
  54. priority: true,
  55. jwt: jwtSetting{
  56. enabled: true,
  57. },
  58. signature: signatureSetting{},
  59. routes: []Route{{
  60. Method: http.MethodGet,
  61. Path: "/",
  62. Handler: func(w http.ResponseWriter, r *http.Request) {},
  63. }},
  64. },
  65. {
  66. priority: true,
  67. jwt: jwtSetting{
  68. enabled: true,
  69. prevSecret: "thesecret",
  70. },
  71. signature: signatureSetting{},
  72. routes: []Route{{
  73. Method: http.MethodGet,
  74. Path: "/",
  75. Handler: func(w http.ResponseWriter, r *http.Request) {},
  76. }},
  77. },
  78. {
  79. priority: true,
  80. jwt: jwtSetting{
  81. enabled: true,
  82. },
  83. signature: signatureSetting{},
  84. routes: []Route{{
  85. Method: http.MethodGet,
  86. Path: "/",
  87. Handler: func(w http.ResponseWriter, r *http.Request) {},
  88. }},
  89. },
  90. {
  91. priority: true,
  92. jwt: jwtSetting{
  93. enabled: true,
  94. },
  95. signature: signatureSetting{
  96. enabled: true,
  97. },
  98. routes: []Route{{
  99. Method: http.MethodGet,
  100. Path: "/",
  101. Handler: func(w http.ResponseWriter, r *http.Request) {},
  102. }},
  103. },
  104. {
  105. priority: true,
  106. jwt: jwtSetting{
  107. enabled: true,
  108. },
  109. signature: signatureSetting{
  110. enabled: true,
  111. SignatureConf: SignatureConf{
  112. Strict: true,
  113. },
  114. },
  115. routes: []Route{{
  116. Method: http.MethodGet,
  117. Path: "/",
  118. Handler: func(w http.ResponseWriter, r *http.Request) {},
  119. }},
  120. },
  121. {
  122. priority: true,
  123. jwt: jwtSetting{
  124. enabled: true,
  125. },
  126. signature: signatureSetting{
  127. enabled: true,
  128. SignatureConf: SignatureConf{
  129. Strict: true,
  130. PrivateKeys: []PrivateKeyConf{
  131. {
  132. Fingerprint: "a",
  133. KeyFile: "b",
  134. },
  135. },
  136. },
  137. },
  138. routes: []Route{{
  139. Method: http.MethodGet,
  140. Path: "/",
  141. Handler: func(w http.ResponseWriter, r *http.Request) {},
  142. }},
  143. },
  144. }
  145. for _, yaml := range yamls {
  146. for _, route := range routes {
  147. var cnf RestConf
  148. assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
  149. ng := newEngine(cnf)
  150. ng.addRoutes(route)
  151. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  152. return func(w http.ResponseWriter, r *http.Request) {
  153. next.ServeHTTP(w, r)
  154. }
  155. })
  156. assert.NotNil(t, ng.start(mockedRouter{}))
  157. }
  158. }
  159. }
  160. func TestEngine_checkedTimeout(t *testing.T) {
  161. tests := []struct {
  162. name string
  163. timeout time.Duration
  164. expect time.Duration
  165. }{
  166. {
  167. name: "not set",
  168. expect: time.Second,
  169. },
  170. {
  171. name: "less",
  172. timeout: time.Millisecond * 500,
  173. expect: time.Millisecond * 500,
  174. },
  175. {
  176. name: "equal",
  177. timeout: time.Second,
  178. expect: time.Second,
  179. },
  180. {
  181. name: "more",
  182. timeout: time.Millisecond * 1500,
  183. expect: time.Millisecond * 1500,
  184. },
  185. }
  186. ng := newEngine(RestConf{
  187. Timeout: 1000,
  188. })
  189. for _, test := range tests {
  190. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  191. }
  192. }
  193. func TestEngine_checkedMaxBytes(t *testing.T) {
  194. tests := []struct {
  195. name string
  196. maxBytes int64
  197. expect int64
  198. }{
  199. {
  200. name: "not set",
  201. expect: 1000,
  202. },
  203. {
  204. name: "less",
  205. maxBytes: 500,
  206. expect: 500,
  207. },
  208. {
  209. name: "equal",
  210. maxBytes: 1000,
  211. expect: 1000,
  212. },
  213. {
  214. name: "more",
  215. maxBytes: 1500,
  216. expect: 1500,
  217. },
  218. }
  219. ng := newEngine(RestConf{
  220. MaxBytes: 1000,
  221. })
  222. for _, test := range tests {
  223. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  224. }
  225. }
  226. func TestEngine_notFoundHandler(t *testing.T) {
  227. logx.Disable()
  228. ng := newEngine(RestConf{})
  229. ts := httptest.NewServer(ng.notFoundHandler(nil))
  230. defer ts.Close()
  231. client := ts.Client()
  232. err := func(_ context.Context) error {
  233. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  234. assert.Nil(t, err)
  235. res, err := client.Do(req)
  236. assert.Nil(t, err)
  237. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  238. return res.Body.Close()
  239. }(context.Background())
  240. assert.Nil(t, err)
  241. }
  242. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  243. logx.Disable()
  244. ng := newEngine(RestConf{})
  245. var called int32
  246. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  247. atomic.AddInt32(&called, 1)
  248. })))
  249. defer ts.Close()
  250. client := ts.Client()
  251. err := func(_ context.Context) error {
  252. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  253. assert.Nil(t, err)
  254. res, err := client.Do(req)
  255. assert.Nil(t, err)
  256. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  257. return res.Body.Close()
  258. }(context.Background())
  259. assert.Nil(t, err)
  260. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  261. }
  262. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  263. logx.Disable()
  264. ng := newEngine(RestConf{})
  265. var called int32
  266. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  267. atomic.AddInt32(&called, 1)
  268. w.WriteHeader(http.StatusExpectationFailed)
  269. })))
  270. defer ts.Close()
  271. client := ts.Client()
  272. err := func(_ context.Context) error {
  273. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  274. assert.Nil(t, err)
  275. res, err := client.Do(req)
  276. assert.Nil(t, err)
  277. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  278. return res.Body.Close()
  279. }(context.Background())
  280. assert.Nil(t, err)
  281. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  282. }
  283. func TestEngine_withTimeout(t *testing.T) {
  284. logx.Disable()
  285. tests := []struct {
  286. name string
  287. timeout int64
  288. }{
  289. {
  290. name: "not set",
  291. },
  292. {
  293. name: "set",
  294. timeout: 1000,
  295. },
  296. }
  297. for _, test := range tests {
  298. test := test
  299. t.Run(test.name, func(t *testing.T) {
  300. ng := newEngine(RestConf{Timeout: test.timeout})
  301. svr := &http.Server{}
  302. ng.withTimeout()(svr)
  303. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
  304. assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
  305. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
  306. assert.Equal(t, time.Duration(0), svr.IdleTimeout)
  307. })
  308. }
  309. }
  // mockedRouter is a stub router for tests: registering a route always fails
// and every other method does nothing.
type mockedRouter struct{}

// ServeHTTP ignores the request and writes nothing.
func (mockedRouter) ServeHTTP(http.ResponseWriter, *http.Request) {}

// Handle rejects every registration with the error "foo".
func (mockedRouter) Handle(string, string, http.Handler) error {
	return errors.New("foo")
}

// SetNotFoundHandler discards the given handler.
func (mockedRouter) SetNotFoundHandler(http.Handler) {}

// SetNotAllowedHandler discards the given handler.
func (mockedRouter) SetNotAllowedHandler(http.Handler) {}