engine_test.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. package rest
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync/atomic"
  8. "testing"
  9. "time"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/zeromicro/go-zero/core/conf"
  12. "github.com/zeromicro/go-zero/core/logx"
  13. )
  14. func TestNewEngine(t *testing.T) {
  15. yamls := []string{
  16. `Name: foo
  17. Port: 54321
  18. `,
  19. `Name: foo
  20. Port: 54321
  21. CpuThreshold: 500
  22. `,
  23. `Name: foo
  24. Port: 54321
  25. CpuThreshold: 500
  26. Verbose: true
  27. `,
  28. }
  29. routes := []featuredRoutes{
  30. {
  31. jwt: jwtSetting{},
  32. signature: signatureSetting{},
  33. routes: []Route{{
  34. Method: http.MethodGet,
  35. Path: "/",
  36. Handler: func(w http.ResponseWriter, r *http.Request) {},
  37. }},
  38. },
  39. {
  40. priority: true,
  41. jwt: jwtSetting{},
  42. signature: signatureSetting{},
  43. routes: []Route{{
  44. Method: http.MethodGet,
  45. Path: "/",
  46. Handler: func(w http.ResponseWriter, r *http.Request) {},
  47. }},
  48. },
  49. {
  50. priority: true,
  51. jwt: jwtSetting{
  52. enabled: true,
  53. },
  54. signature: signatureSetting{},
  55. routes: []Route{{
  56. Method: http.MethodGet,
  57. Path: "/",
  58. Handler: func(w http.ResponseWriter, r *http.Request) {},
  59. }},
  60. },
  61. {
  62. priority: true,
  63. jwt: jwtSetting{
  64. enabled: true,
  65. prevSecret: "thesecret",
  66. },
  67. signature: signatureSetting{},
  68. routes: []Route{{
  69. Method: http.MethodGet,
  70. Path: "/",
  71. Handler: func(w http.ResponseWriter, r *http.Request) {},
  72. }},
  73. },
  74. {
  75. priority: true,
  76. jwt: jwtSetting{
  77. enabled: true,
  78. },
  79. signature: signatureSetting{},
  80. routes: []Route{{
  81. Method: http.MethodGet,
  82. Path: "/",
  83. Handler: func(w http.ResponseWriter, r *http.Request) {},
  84. }},
  85. },
  86. {
  87. priority: true,
  88. jwt: jwtSetting{
  89. enabled: true,
  90. },
  91. signature: signatureSetting{
  92. enabled: true,
  93. },
  94. routes: []Route{{
  95. Method: http.MethodGet,
  96. Path: "/",
  97. Handler: func(w http.ResponseWriter, r *http.Request) {},
  98. }},
  99. },
  100. {
  101. priority: true,
  102. jwt: jwtSetting{
  103. enabled: true,
  104. },
  105. signature: signatureSetting{
  106. enabled: true,
  107. SignatureConf: SignatureConf{
  108. Strict: true,
  109. },
  110. },
  111. routes: []Route{{
  112. Method: http.MethodGet,
  113. Path: "/",
  114. Handler: func(w http.ResponseWriter, r *http.Request) {},
  115. }},
  116. },
  117. {
  118. priority: true,
  119. jwt: jwtSetting{
  120. enabled: true,
  121. },
  122. signature: signatureSetting{
  123. enabled: true,
  124. SignatureConf: SignatureConf{
  125. Strict: true,
  126. PrivateKeys: []PrivateKeyConf{
  127. {
  128. Fingerprint: "a",
  129. KeyFile: "b",
  130. },
  131. },
  132. },
  133. },
  134. routes: []Route{{
  135. Method: http.MethodGet,
  136. Path: "/",
  137. Handler: func(w http.ResponseWriter, r *http.Request) {},
  138. }},
  139. },
  140. }
  141. for _, yaml := range yamls {
  142. for _, route := range routes {
  143. var cnf RestConf
  144. assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
  145. ng := newEngine(cnf)
  146. ng.addRoutes(route)
  147. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  148. return func(w http.ResponseWriter, r *http.Request) {
  149. next.ServeHTTP(w, r)
  150. }
  151. })
  152. assert.NotNil(t, ng.start(mockedRouter{}))
  153. }
  154. }
  155. }
  156. func TestEngine_checkedTimeout(t *testing.T) {
  157. tests := []struct {
  158. name string
  159. timeout time.Duration
  160. expect time.Duration
  161. }{
  162. {
  163. name: "not set",
  164. expect: time.Second,
  165. },
  166. {
  167. name: "less",
  168. timeout: time.Millisecond * 500,
  169. expect: time.Millisecond * 500,
  170. },
  171. {
  172. name: "equal",
  173. timeout: time.Second,
  174. expect: time.Second,
  175. },
  176. {
  177. name: "more",
  178. timeout: time.Millisecond * 1500,
  179. expect: time.Millisecond * 1500,
  180. },
  181. }
  182. ng := newEngine(RestConf{
  183. Timeout: 1000,
  184. })
  185. for _, test := range tests {
  186. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  187. }
  188. }
  189. func TestEngine_checkedMaxBytes(t *testing.T) {
  190. tests := []struct {
  191. name string
  192. maxBytes int64
  193. expect int64
  194. }{
  195. {
  196. name: "not set",
  197. expect: 1000,
  198. },
  199. {
  200. name: "less",
  201. maxBytes: 500,
  202. expect: 500,
  203. },
  204. {
  205. name: "equal",
  206. maxBytes: 1000,
  207. expect: 1000,
  208. },
  209. {
  210. name: "more",
  211. maxBytes: 1500,
  212. expect: 1500,
  213. },
  214. }
  215. ng := newEngine(RestConf{
  216. MaxBytes: 1000,
  217. })
  218. for _, test := range tests {
  219. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  220. }
  221. }
  222. func TestEngine_notFoundHandler(t *testing.T) {
  223. logx.Disable()
  224. ng := newEngine(RestConf{})
  225. ts := httptest.NewServer(ng.notFoundHandler(nil))
  226. defer ts.Close()
  227. client := ts.Client()
  228. err := func(ctx context.Context) error {
  229. req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
  230. assert.Nil(t, err)
  231. res, err := client.Do(req)
  232. assert.Nil(t, err)
  233. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  234. return res.Body.Close()
  235. }(context.Background())
  236. assert.Nil(t, err)
  237. }
  238. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  239. logx.Disable()
  240. ng := newEngine(RestConf{})
  241. var called int32
  242. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  243. atomic.AddInt32(&called, 1)
  244. })))
  245. defer ts.Close()
  246. client := ts.Client()
  247. err := func(ctx context.Context) error {
  248. req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
  249. assert.Nil(t, err)
  250. res, err := client.Do(req)
  251. assert.Nil(t, err)
  252. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  253. return res.Body.Close()
  254. }(context.Background())
  255. assert.Nil(t, err)
  256. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  257. }
  258. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  259. logx.Disable()
  260. ng := newEngine(RestConf{})
  261. var called int32
  262. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  263. atomic.AddInt32(&called, 1)
  264. w.WriteHeader(http.StatusExpectationFailed)
  265. })))
  266. defer ts.Close()
  267. client := ts.Client()
  268. err := func(ctx context.Context) error {
  269. req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
  270. assert.Nil(t, err)
  271. res, err := client.Do(req)
  272. assert.Nil(t, err)
  273. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  274. return res.Body.Close()
  275. }(context.Background())
  276. assert.Nil(t, err)
  277. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  278. }
  279. func TestEngine_withTimeout(t *testing.T) {
  280. logx.Disable()
  281. tests := []struct {
  282. name string
  283. timeout int64
  284. }{
  285. {
  286. name: "not set",
  287. },
  288. {
  289. name: "set",
  290. timeout: 1000,
  291. },
  292. }
  293. for _, test := range tests {
  294. test := test
  295. t.Run(test.name, func(t *testing.T) {
  296. ng := newEngine(RestConf{Timeout: test.timeout})
  297. svr := &http.Server{}
  298. ng.withTimeout()(svr)
  299. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
  300. assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
  301. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*9/10, svr.WriteTimeout)
  302. assert.Equal(t, time.Duration(0), svr.IdleTimeout)
  303. })
  304. }
  305. }
  306. type mockedRouter struct{}
  307. func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
  308. }
  309. func (m mockedRouter) Handle(_, _ string, _ http.Handler) error {
  310. return errors.New("foo")
  311. }
  312. func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
  313. }
  314. func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
  315. }