engine_test.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. package rest
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "github.com/wuntsong-org/go-zero-plus/rest/httpx"
  8. "net/http"
  9. "net/http/httptest"
  10. "os"
  11. "sync/atomic"
  12. "testing"
  13. "time"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/wuntsong-org/go-zero-plus/core/conf"
  16. "github.com/wuntsong-org/go-zero-plus/core/fs"
  17. "github.com/wuntsong-org/go-zero-plus/core/logx"
  18. "github.com/wuntsong-org/go-zero-plus/rest/router"
  19. )
  20. const (
  21. priKey = `-----BEGIN RSA PRIVATE KEY-----
  22. MIICXQIBAAKBgQC4TJk3onpqb2RYE3wwt23J9SHLFstHGSkUYFLe+nl1dEKHbD+/
  23. Zt95L757J3xGTrwoTc7KCTxbrgn+stn0w52BNjj/kIE2ko4lbh/v8Fl14AyVR9ms
  24. fKtKOnhe5FCT72mdtApr+qvzcC3q9hfXwkyQU32pv7q5UimZ205iKSBmgQIDAQAB
  25. AoGAM5mWqGIAXj5z3MkP01/4CDxuyrrGDVD5FHBno3CDgyQa4Gmpa4B0/ywj671B
  26. aTnwKmSmiiCN2qleuQYASixes2zY5fgTzt+7KNkl9JHsy7i606eH2eCKzsUa/s6u
  27. WD8V3w/hGCQ9zYI18ihwyXlGHIgcRz/eeRh+nWcWVJzGOPUCQQD5nr6It/1yHb1p
  28. C6l4fC4xXF19l4KxJjGu1xv/sOpSx0pOqBDEX3Mh//FU954392rUWDXV1/I65BPt
  29. TLphdsu3AkEAvQJ2Qay/lffFj9FaUrvXuftJZ/Ypn0FpaSiUh3Ak3obBT6UvSZS0
  30. bcYdCJCNHDtBOsWHnIN1x+BcWAPrdU7PhwJBAIQ0dUlH2S3VXnoCOTGc44I1Hzbj
  31. Rc65IdsuBqA3fQN2lX5vOOIog3vgaFrOArg1jBkG1wx5IMvb/EnUN2pjVqUCQCza
  32. KLXtCInOAlPemlCHwumfeAvznmzsWNdbieOZ+SXVVIpR6KbNYwOpv7oIk3Pfm9sW
  33. hNffWlPUKhW42Gc+DIECQQDmk20YgBXwXWRM5DRPbhisIV088N5Z58K9DtFWkZsd
  34. OBDT3dFcgZONtlmR1MqZO0pTh30lA4qovYj3Bx7A8i36
  35. -----END RSA PRIVATE KEY-----`
  36. )
  37. func TestNewEngine(t *testing.T) {
  38. priKeyfile, err := fs.TempFilenameWithText(priKey)
  39. assert.Nil(t, err)
  40. defer os.Remove(priKeyfile)
  41. yamls := []string{
  42. `Name: foo
  43. Host: localhost
  44. Port: 0
  45. Middlewares:
  46. Log: false
  47. `,
  48. `Name: foo
  49. Host: localhost
  50. Port: 0
  51. CpuThreshold: 500
  52. Middlewares:
  53. Log: false
  54. `,
  55. `Name: foo
  56. Host: localhost
  57. Port: 0
  58. CpuThreshold: 500
  59. Verbose: true
  60. `,
  61. }
  62. routes := []featuredRoutes{
  63. {
  64. jwt: jwtSetting{},
  65. signature: signatureSetting{},
  66. routes: []Route{{
  67. Method: http.MethodGet,
  68. Path: "/",
  69. Handler: func(w http.ResponseWriter, r *http.Request) {},
  70. }},
  71. timeout: time.Minute,
  72. },
  73. {
  74. priority: true,
  75. jwt: jwtSetting{},
  76. signature: signatureSetting{},
  77. routes: []Route{{
  78. Method: http.MethodGet,
  79. Path: "/",
  80. Handler: func(w http.ResponseWriter, r *http.Request) {},
  81. }},
  82. timeout: time.Second,
  83. },
  84. {
  85. priority: true,
  86. jwt: jwtSetting{
  87. enabled: true,
  88. },
  89. signature: signatureSetting{},
  90. routes: []Route{{
  91. Method: http.MethodGet,
  92. Path: "/",
  93. Handler: func(w http.ResponseWriter, r *http.Request) {},
  94. }},
  95. },
  96. {
  97. priority: true,
  98. jwt: jwtSetting{
  99. enabled: true,
  100. prevSecret: "thesecret",
  101. },
  102. signature: signatureSetting{},
  103. routes: []Route{{
  104. Method: http.MethodGet,
  105. Path: "/",
  106. Handler: func(w http.ResponseWriter, r *http.Request) {},
  107. }},
  108. },
  109. {
  110. priority: true,
  111. jwt: jwtSetting{
  112. enabled: true,
  113. },
  114. signature: signatureSetting{},
  115. routes: []Route{{
  116. Method: http.MethodGet,
  117. Path: "/",
  118. Handler: func(w http.ResponseWriter, r *http.Request) {},
  119. }},
  120. },
  121. {
  122. priority: true,
  123. jwt: jwtSetting{
  124. enabled: true,
  125. },
  126. signature: signatureSetting{
  127. enabled: true,
  128. },
  129. routes: []Route{{
  130. Method: http.MethodGet,
  131. Path: "/",
  132. Handler: func(w http.ResponseWriter, r *http.Request) {},
  133. }},
  134. },
  135. {
  136. priority: true,
  137. jwt: jwtSetting{
  138. enabled: true,
  139. },
  140. signature: signatureSetting{
  141. enabled: true,
  142. SignatureConf: SignatureConf{
  143. Strict: true,
  144. },
  145. },
  146. routes: []Route{{
  147. Method: http.MethodGet,
  148. Path: "/",
  149. Handler: func(w http.ResponseWriter, r *http.Request) {},
  150. }},
  151. },
  152. {
  153. priority: true,
  154. jwt: jwtSetting{
  155. enabled: true,
  156. },
  157. signature: signatureSetting{
  158. enabled: true,
  159. SignatureConf: SignatureConf{
  160. Strict: true,
  161. PrivateKeys: []PrivateKeyConf{
  162. {
  163. Fingerprint: "a",
  164. KeyFile: "b",
  165. },
  166. },
  167. },
  168. },
  169. routes: []Route{{
  170. Method: http.MethodGet,
  171. Path: "/",
  172. Handler: func(w http.ResponseWriter, r *http.Request) {},
  173. }},
  174. },
  175. {
  176. priority: true,
  177. jwt: jwtSetting{
  178. enabled: true,
  179. },
  180. signature: signatureSetting{
  181. enabled: true,
  182. SignatureConf: SignatureConf{
  183. Strict: true,
  184. PrivateKeys: []PrivateKeyConf{
  185. {
  186. Fingerprint: "a",
  187. KeyFile: priKeyfile,
  188. },
  189. },
  190. },
  191. },
  192. routes: []Route{{
  193. Method: http.MethodGet,
  194. Path: "/",
  195. Handler: func(w http.ResponseWriter, r *http.Request) {},
  196. }},
  197. },
  198. }
  199. var index int32
  200. for _, yaml := range yamls {
  201. yaml := yaml
  202. for _, route := range routes {
  203. route := route
  204. t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
  205. var cnf RestConf
  206. assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
  207. ng := newEngine(cnf)
  208. if atomic.AddInt32(&index, 1)%2 == 0 {
  209. ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
  210. next http.Handler, strict bool, code int) {
  211. })
  212. }
  213. ng.addRoutes(route)
  214. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  215. return func(w http.ResponseWriter, r *http.Request) {
  216. next.ServeHTTP(w, r)
  217. }
  218. })
  219. assert.NotNil(t, ng.start(nil, mockedRouter{}, func(svr *http.Server) {
  220. }))
  221. timeout := time.Second * 3
  222. if route.timeout > timeout {
  223. timeout = route.timeout
  224. }
  225. assert.Equal(t, timeout, ng.timeout)
  226. })
  227. }
  228. }
  229. }
  230. func TestEngine_checkedTimeout(t *testing.T) {
  231. tests := []struct {
  232. name string
  233. timeout time.Duration
  234. expect time.Duration
  235. }{
  236. {
  237. name: "not set",
  238. expect: time.Second,
  239. },
  240. {
  241. name: "less",
  242. timeout: time.Millisecond * 500,
  243. expect: time.Millisecond * 500,
  244. },
  245. {
  246. name: "equal",
  247. timeout: time.Second,
  248. expect: time.Second,
  249. },
  250. {
  251. name: "more",
  252. timeout: time.Millisecond * 1500,
  253. expect: time.Millisecond * 1500,
  254. },
  255. }
  256. ng := newEngine(RestConf{
  257. Timeout: 1000,
  258. })
  259. for _, test := range tests {
  260. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  261. }
  262. }
  263. func TestEngine_checkedMaxBytes(t *testing.T) {
  264. tests := []struct {
  265. name string
  266. maxBytes int64
  267. expect int64
  268. }{
  269. {
  270. name: "not set",
  271. expect: 1000,
  272. },
  273. {
  274. name: "less",
  275. maxBytes: 500,
  276. expect: 500,
  277. },
  278. {
  279. name: "equal",
  280. maxBytes: 1000,
  281. expect: 1000,
  282. },
  283. {
  284. name: "more",
  285. maxBytes: 1500,
  286. expect: 1500,
  287. },
  288. }
  289. ng := newEngine(RestConf{
  290. MaxBytes: 1000,
  291. })
  292. for _, test := range tests {
  293. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  294. }
  295. }
  296. func TestEngine_notFoundHandler(t *testing.T) {
  297. logx.Disable()
  298. ng := newEngine(RestConf{})
  299. ts := httptest.NewServer(ng.notFoundHandler(nil))
  300. defer ts.Close()
  301. client := ts.Client()
  302. err := func(_ context.Context) error {
  303. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  304. assert.Nil(t, err)
  305. res, err := client.Do(req)
  306. assert.Nil(t, err)
  307. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  308. return res.Body.Close()
  309. }(context.Background())
  310. assert.Nil(t, err)
  311. }
  312. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  313. logx.Disable()
  314. ng := newEngine(RestConf{})
  315. var called int32
  316. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  317. atomic.AddInt32(&called, 1)
  318. })))
  319. defer ts.Close()
  320. client := ts.Client()
  321. err := func(_ context.Context) error {
  322. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  323. assert.Nil(t, err)
  324. res, err := client.Do(req)
  325. assert.Nil(t, err)
  326. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  327. return res.Body.Close()
  328. }(context.Background())
  329. assert.Nil(t, err)
  330. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  331. }
  332. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  333. logx.Disable()
  334. ng := newEngine(RestConf{})
  335. var called int32
  336. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  337. atomic.AddInt32(&called, 1)
  338. w.WriteHeader(http.StatusExpectationFailed)
  339. })))
  340. defer ts.Close()
  341. client := ts.Client()
  342. err := func(_ context.Context) error {
  343. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  344. assert.Nil(t, err)
  345. res, err := client.Do(req)
  346. assert.Nil(t, err)
  347. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  348. return res.Body.Close()
  349. }(context.Background())
  350. assert.Nil(t, err)
  351. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  352. }
  353. func TestEngine_withTimeout(t *testing.T) {
  354. logx.Disable()
  355. tests := []struct {
  356. name string
  357. timeout int64
  358. }{
  359. {
  360. name: "not set",
  361. },
  362. {
  363. name: "set",
  364. timeout: 1000,
  365. },
  366. }
  367. for _, test := range tests {
  368. test := test
  369. t.Run(test.name, func(t *testing.T) {
  370. ng := newEngine(RestConf{Timeout: test.timeout})
  371. svr := &http.Server{}
  372. ng.withTimeout()(svr)
  373. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
  374. assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
  375. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
  376. assert.Equal(t, time.Duration(0), svr.IdleTimeout)
  377. })
  378. }
  379. }
  380. func TestEngine_start(t *testing.T) {
  381. logx.Disable()
  382. t.Run("http", func(t *testing.T) {
  383. ng := newEngine(RestConf{
  384. Host: "localhost",
  385. Port: -1,
  386. })
  387. assert.Error(t, ng.start(nil, router.NewRouter()))
  388. })
  389. t.Run("https", func(t *testing.T) {
  390. ng := newEngine(RestConf{
  391. Host: "localhost",
  392. Port: -1,
  393. CertFile: "foo",
  394. KeyFile: "bar",
  395. })
  396. ng.tlsConfig = &tls.Config{}
  397. assert.Error(t, ng.start(nil, router.NewRouter()))
  398. })
  399. }
  400. type mockedRouter struct {
  401. }
  402. func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
  403. }
  404. func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
  405. return errors.New("foo")
  406. }
  407. func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
  408. }
  409. func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
  410. }
  411. func (m mockedRouter) SetOptionsHandler(_ http.Handler) {
  412. }
  413. func (m mockedRouter) SetMiddleware(_ httpx.MiddlewareFunc) {
  414. }