engine_test.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. package rest
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/zeromicro/go-zero/core/conf"
  15. "github.com/zeromicro/go-zero/core/fs"
  16. "github.com/zeromicro/go-zero/core/logx"
  17. "github.com/zeromicro/go-zero/rest/router"
  18. )
  19. const (
  20. priKey = `-----BEGIN RSA PRIVATE KEY-----
  21. MIICXQIBAAKBgQC4TJk3onpqb2RYE3wwt23J9SHLFstHGSkUYFLe+nl1dEKHbD+/
  22. Zt95L757J3xGTrwoTc7KCTxbrgn+stn0w52BNjj/kIE2ko4lbh/v8Fl14AyVR9ms
  23. fKtKOnhe5FCT72mdtApr+qvzcC3q9hfXwkyQU32pv7q5UimZ205iKSBmgQIDAQAB
  24. AoGAM5mWqGIAXj5z3MkP01/4CDxuyrrGDVD5FHBno3CDgyQa4Gmpa4B0/ywj671B
  25. aTnwKmSmiiCN2qleuQYASixes2zY5fgTzt+7KNkl9JHsy7i606eH2eCKzsUa/s6u
  26. WD8V3w/hGCQ9zYI18ihwyXlGHIgcRz/eeRh+nWcWVJzGOPUCQQD5nr6It/1yHb1p
  27. C6l4fC4xXF19l4KxJjGu1xv/sOpSx0pOqBDEX3Mh//FU954392rUWDXV1/I65BPt
  28. TLphdsu3AkEAvQJ2Qay/lffFj9FaUrvXuftJZ/Ypn0FpaSiUh3Ak3obBT6UvSZS0
  29. bcYdCJCNHDtBOsWHnIN1x+BcWAPrdU7PhwJBAIQ0dUlH2S3VXnoCOTGc44I1Hzbj
  30. Rc65IdsuBqA3fQN2lX5vOOIog3vgaFrOArg1jBkG1wx5IMvb/EnUN2pjVqUCQCza
  31. KLXtCInOAlPemlCHwumfeAvznmzsWNdbieOZ+SXVVIpR6KbNYwOpv7oIk3Pfm9sW
  32. hNffWlPUKhW42Gc+DIECQQDmk20YgBXwXWRM5DRPbhisIV088N5Z58K9DtFWkZsd
  33. OBDT3dFcgZONtlmR1MqZO0pTh30lA4qovYj3Bx7A8i36
  34. -----END RSA PRIVATE KEY-----`
  35. )
  36. func TestNewEngine(t *testing.T) {
  37. priKeyfile, err := fs.TempFilenameWithText(priKey)
  38. assert.Nil(t, err)
  39. defer os.Remove(priKeyfile)
  40. yamls := []string{
  41. `Name: foo
  42. Host: localhost
  43. Port: 0
  44. Middlewares:
  45. Log: false
  46. `,
  47. `Name: foo
  48. Host: localhost
  49. Port: 0
  50. CpuThreshold: 500
  51. Middlewares:
  52. Log: false
  53. `,
  54. `Name: foo
  55. Host: localhost
  56. Port: 0
  57. CpuThreshold: 500
  58. Verbose: true
  59. `,
  60. }
  61. routes := []featuredRoutes{
  62. {
  63. jwt: jwtSetting{},
  64. signature: signatureSetting{},
  65. routes: []Route{{
  66. Method: http.MethodGet,
  67. Path: "/",
  68. Handler: func(w http.ResponseWriter, r *http.Request) {},
  69. }},
  70. timeout: time.Minute,
  71. },
  72. {
  73. priority: true,
  74. jwt: jwtSetting{},
  75. signature: signatureSetting{},
  76. routes: []Route{{
  77. Method: http.MethodGet,
  78. Path: "/",
  79. Handler: func(w http.ResponseWriter, r *http.Request) {},
  80. }},
  81. timeout: time.Second,
  82. },
  83. {
  84. priority: true,
  85. jwt: jwtSetting{
  86. enabled: true,
  87. },
  88. signature: signatureSetting{},
  89. routes: []Route{{
  90. Method: http.MethodGet,
  91. Path: "/",
  92. Handler: func(w http.ResponseWriter, r *http.Request) {},
  93. }},
  94. },
  95. {
  96. priority: true,
  97. jwt: jwtSetting{
  98. enabled: true,
  99. prevSecret: "thesecret",
  100. },
  101. signature: signatureSetting{},
  102. routes: []Route{{
  103. Method: http.MethodGet,
  104. Path: "/",
  105. Handler: func(w http.ResponseWriter, r *http.Request) {},
  106. }},
  107. },
  108. {
  109. priority: true,
  110. jwt: jwtSetting{
  111. enabled: true,
  112. },
  113. signature: signatureSetting{},
  114. routes: []Route{{
  115. Method: http.MethodGet,
  116. Path: "/",
  117. Handler: func(w http.ResponseWriter, r *http.Request) {},
  118. }},
  119. },
  120. {
  121. priority: true,
  122. jwt: jwtSetting{
  123. enabled: true,
  124. },
  125. signature: signatureSetting{
  126. enabled: true,
  127. },
  128. routes: []Route{{
  129. Method: http.MethodGet,
  130. Path: "/",
  131. Handler: func(w http.ResponseWriter, r *http.Request) {},
  132. }},
  133. },
  134. {
  135. priority: true,
  136. jwt: jwtSetting{
  137. enabled: true,
  138. },
  139. signature: signatureSetting{
  140. enabled: true,
  141. SignatureConf: SignatureConf{
  142. Strict: true,
  143. },
  144. },
  145. routes: []Route{{
  146. Method: http.MethodGet,
  147. Path: "/",
  148. Handler: func(w http.ResponseWriter, r *http.Request) {},
  149. }},
  150. },
  151. {
  152. priority: true,
  153. jwt: jwtSetting{
  154. enabled: true,
  155. },
  156. signature: signatureSetting{
  157. enabled: true,
  158. SignatureConf: SignatureConf{
  159. Strict: true,
  160. PrivateKeys: []PrivateKeyConf{
  161. {
  162. Fingerprint: "a",
  163. KeyFile: "b",
  164. },
  165. },
  166. },
  167. },
  168. routes: []Route{{
  169. Method: http.MethodGet,
  170. Path: "/",
  171. Handler: func(w http.ResponseWriter, r *http.Request) {},
  172. }},
  173. },
  174. {
  175. priority: true,
  176. jwt: jwtSetting{
  177. enabled: true,
  178. },
  179. signature: signatureSetting{
  180. enabled: true,
  181. SignatureConf: SignatureConf{
  182. Strict: true,
  183. PrivateKeys: []PrivateKeyConf{
  184. {
  185. Fingerprint: "a",
  186. KeyFile: priKeyfile,
  187. },
  188. },
  189. },
  190. },
  191. routes: []Route{{
  192. Method: http.MethodGet,
  193. Path: "/",
  194. Handler: func(w http.ResponseWriter, r *http.Request) {},
  195. }},
  196. },
  197. }
  198. var index int32
  199. for _, yaml := range yamls {
  200. yaml := yaml
  201. for _, route := range routes {
  202. route := route
  203. t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
  204. var cnf RestConf
  205. assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
  206. ng := newEngine(cnf)
  207. if atomic.AddInt32(&index, 1)%2 == 0 {
  208. ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
  209. next http.Handler, strict bool, code int) {
  210. })
  211. }
  212. ng.addRoutes(route)
  213. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  214. return func(w http.ResponseWriter, r *http.Request) {
  215. next.ServeHTTP(w, r)
  216. }
  217. })
  218. assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
  219. }))
  220. timeout := time.Second * 3
  221. if route.timeout > timeout {
  222. timeout = route.timeout
  223. }
  224. assert.Equal(t, timeout, ng.timeout)
  225. })
  226. }
  227. }
  228. }
  229. func TestEngine_checkedTimeout(t *testing.T) {
  230. tests := []struct {
  231. name string
  232. timeout time.Duration
  233. expect time.Duration
  234. }{
  235. {
  236. name: "not set",
  237. expect: time.Second,
  238. },
  239. {
  240. name: "less",
  241. timeout: time.Millisecond * 500,
  242. expect: time.Millisecond * 500,
  243. },
  244. {
  245. name: "equal",
  246. timeout: time.Second,
  247. expect: time.Second,
  248. },
  249. {
  250. name: "more",
  251. timeout: time.Millisecond * 1500,
  252. expect: time.Millisecond * 1500,
  253. },
  254. }
  255. ng := newEngine(RestConf{
  256. Timeout: 1000,
  257. })
  258. for _, test := range tests {
  259. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  260. }
  261. }
  262. func TestEngine_checkedMaxBytes(t *testing.T) {
  263. tests := []struct {
  264. name string
  265. maxBytes int64
  266. expect int64
  267. }{
  268. {
  269. name: "not set",
  270. expect: 1000,
  271. },
  272. {
  273. name: "less",
  274. maxBytes: 500,
  275. expect: 500,
  276. },
  277. {
  278. name: "equal",
  279. maxBytes: 1000,
  280. expect: 1000,
  281. },
  282. {
  283. name: "more",
  284. maxBytes: 1500,
  285. expect: 1500,
  286. },
  287. }
  288. ng := newEngine(RestConf{
  289. MaxBytes: 1000,
  290. })
  291. for _, test := range tests {
  292. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  293. }
  294. }
  295. func TestEngine_notFoundHandler(t *testing.T) {
  296. logx.Disable()
  297. ng := newEngine(RestConf{})
  298. ts := httptest.NewServer(ng.notFoundHandler(nil))
  299. defer ts.Close()
  300. client := ts.Client()
  301. err := func(_ context.Context) error {
  302. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  303. assert.Nil(t, err)
  304. res, err := client.Do(req)
  305. assert.Nil(t, err)
  306. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  307. return res.Body.Close()
  308. }(context.Background())
  309. assert.Nil(t, err)
  310. }
  311. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  312. logx.Disable()
  313. ng := newEngine(RestConf{})
  314. var called int32
  315. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  316. atomic.AddInt32(&called, 1)
  317. })))
  318. defer ts.Close()
  319. client := ts.Client()
  320. err := func(_ context.Context) error {
  321. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  322. assert.Nil(t, err)
  323. res, err := client.Do(req)
  324. assert.Nil(t, err)
  325. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  326. return res.Body.Close()
  327. }(context.Background())
  328. assert.Nil(t, err)
  329. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  330. }
  331. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  332. logx.Disable()
  333. ng := newEngine(RestConf{})
  334. var called int32
  335. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  336. atomic.AddInt32(&called, 1)
  337. w.WriteHeader(http.StatusExpectationFailed)
  338. })))
  339. defer ts.Close()
  340. client := ts.Client()
  341. err := func(_ context.Context) error {
  342. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  343. assert.Nil(t, err)
  344. res, err := client.Do(req)
  345. assert.Nil(t, err)
  346. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  347. return res.Body.Close()
  348. }(context.Background())
  349. assert.Nil(t, err)
  350. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  351. }
  352. func TestEngine_withTimeout(t *testing.T) {
  353. logx.Disable()
  354. tests := []struct {
  355. name string
  356. timeout int64
  357. }{
  358. {
  359. name: "not set",
  360. },
  361. {
  362. name: "set",
  363. timeout: 1000,
  364. },
  365. }
  366. for _, test := range tests {
  367. test := test
  368. t.Run(test.name, func(t *testing.T) {
  369. ng := newEngine(RestConf{Timeout: test.timeout})
  370. svr := &http.Server{}
  371. ng.withTimeout()(svr)
  372. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
  373. assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
  374. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
  375. assert.Equal(t, time.Duration(0), svr.IdleTimeout)
  376. })
  377. }
  378. }
  379. func TestEngine_start(t *testing.T) {
  380. logx.Disable()
  381. t.Run("http", func(t *testing.T) {
  382. ng := newEngine(RestConf{
  383. Host: "localhost",
  384. Port: -1,
  385. })
  386. assert.Error(t, ng.start(router.NewRouter()))
  387. })
  388. t.Run("https", func(t *testing.T) {
  389. ng := newEngine(RestConf{
  390. Host: "localhost",
  391. Port: -1,
  392. CertFile: "foo",
  393. KeyFile: "bar",
  394. })
  395. ng.tlsConfig = &tls.Config{}
  396. assert.Error(t, ng.start(router.NewRouter()))
  397. })
  398. }
  399. type mockedRouter struct {
  400. }
  401. func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
  402. }
  403. func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
  404. return errors.New("foo")
  405. }
  406. func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
  407. }
  408. func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
  409. }