engine_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package rest
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "sync/atomic"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/zeromicro/go-zero/core/conf"
  14. "github.com/zeromicro/go-zero/core/fs"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. )
  17. const (
  18. priKey = `-----BEGIN RSA PRIVATE KEY-----
  19. MIICXQIBAAKBgQC4TJk3onpqb2RYE3wwt23J9SHLFstHGSkUYFLe+nl1dEKHbD+/
  20. Zt95L757J3xGTrwoTc7KCTxbrgn+stn0w52BNjj/kIE2ko4lbh/v8Fl14AyVR9ms
  21. fKtKOnhe5FCT72mdtApr+qvzcC3q9hfXwkyQU32pv7q5UimZ205iKSBmgQIDAQAB
  22. AoGAM5mWqGIAXj5z3MkP01/4CDxuyrrGDVD5FHBno3CDgyQa4Gmpa4B0/ywj671B
  23. aTnwKmSmiiCN2qleuQYASixes2zY5fgTzt+7KNkl9JHsy7i606eH2eCKzsUa/s6u
  24. WD8V3w/hGCQ9zYI18ihwyXlGHIgcRz/eeRh+nWcWVJzGOPUCQQD5nr6It/1yHb1p
  25. C6l4fC4xXF19l4KxJjGu1xv/sOpSx0pOqBDEX3Mh//FU954392rUWDXV1/I65BPt
  26. TLphdsu3AkEAvQJ2Qay/lffFj9FaUrvXuftJZ/Ypn0FpaSiUh3Ak3obBT6UvSZS0
  27. bcYdCJCNHDtBOsWHnIN1x+BcWAPrdU7PhwJBAIQ0dUlH2S3VXnoCOTGc44I1Hzbj
  28. Rc65IdsuBqA3fQN2lX5vOOIog3vgaFrOArg1jBkG1wx5IMvb/EnUN2pjVqUCQCza
  29. KLXtCInOAlPemlCHwumfeAvznmzsWNdbieOZ+SXVVIpR6KbNYwOpv7oIk3Pfm9sW
  30. hNffWlPUKhW42Gc+DIECQQDmk20YgBXwXWRM5DRPbhisIV088N5Z58K9DtFWkZsd
  31. OBDT3dFcgZONtlmR1MqZO0pTh30lA4qovYj3Bx7A8i36
  32. -----END RSA PRIVATE KEY-----`
  33. )
  34. func TestNewEngine(t *testing.T) {
  35. priKeyfile, err := fs.TempFilenameWithText(priKey)
  36. assert.Nil(t, err)
  37. defer os.Remove(priKeyfile)
  38. yamls := []string{
  39. `Name: foo
  40. Host: localhost
  41. Port: 0
  42. Middlewares:
  43. Log: false
  44. `,
  45. `Name: foo
  46. Host: localhost
  47. Port: 0
  48. CpuThreshold: 500
  49. Middlewares:
  50. Log: false
  51. `,
  52. `Name: foo
  53. Host: localhost
  54. Port: 0
  55. CpuThreshold: 500
  56. Verbose: true
  57. `,
  58. }
  59. routes := []featuredRoutes{
  60. {
  61. jwt: jwtSetting{},
  62. signature: signatureSetting{},
  63. routes: []Route{{
  64. Method: http.MethodGet,
  65. Path: "/",
  66. Handler: func(w http.ResponseWriter, r *http.Request) {},
  67. }},
  68. timeout: time.Minute,
  69. },
  70. {
  71. priority: true,
  72. jwt: jwtSetting{},
  73. signature: signatureSetting{},
  74. routes: []Route{{
  75. Method: http.MethodGet,
  76. Path: "/",
  77. Handler: func(w http.ResponseWriter, r *http.Request) {},
  78. }},
  79. timeout: time.Second,
  80. },
  81. {
  82. priority: true,
  83. jwt: jwtSetting{
  84. enabled: true,
  85. },
  86. signature: signatureSetting{},
  87. routes: []Route{{
  88. Method: http.MethodGet,
  89. Path: "/",
  90. Handler: func(w http.ResponseWriter, r *http.Request) {},
  91. }},
  92. },
  93. {
  94. priority: true,
  95. jwt: jwtSetting{
  96. enabled: true,
  97. prevSecret: "thesecret",
  98. },
  99. signature: signatureSetting{},
  100. routes: []Route{{
  101. Method: http.MethodGet,
  102. Path: "/",
  103. Handler: func(w http.ResponseWriter, r *http.Request) {},
  104. }},
  105. },
  106. {
  107. priority: true,
  108. jwt: jwtSetting{
  109. enabled: true,
  110. },
  111. signature: signatureSetting{},
  112. routes: []Route{{
  113. Method: http.MethodGet,
  114. Path: "/",
  115. Handler: func(w http.ResponseWriter, r *http.Request) {},
  116. }},
  117. },
  118. {
  119. priority: true,
  120. jwt: jwtSetting{
  121. enabled: true,
  122. },
  123. signature: signatureSetting{
  124. enabled: true,
  125. },
  126. routes: []Route{{
  127. Method: http.MethodGet,
  128. Path: "/",
  129. Handler: func(w http.ResponseWriter, r *http.Request) {},
  130. }},
  131. },
  132. {
  133. priority: true,
  134. jwt: jwtSetting{
  135. enabled: true,
  136. },
  137. signature: signatureSetting{
  138. enabled: true,
  139. SignatureConf: SignatureConf{
  140. Strict: true,
  141. },
  142. },
  143. routes: []Route{{
  144. Method: http.MethodGet,
  145. Path: "/",
  146. Handler: func(w http.ResponseWriter, r *http.Request) {},
  147. }},
  148. },
  149. {
  150. priority: true,
  151. jwt: jwtSetting{
  152. enabled: true,
  153. },
  154. signature: signatureSetting{
  155. enabled: true,
  156. SignatureConf: SignatureConf{
  157. Strict: true,
  158. PrivateKeys: []PrivateKeyConf{
  159. {
  160. Fingerprint: "a",
  161. KeyFile: "b",
  162. },
  163. },
  164. },
  165. },
  166. routes: []Route{{
  167. Method: http.MethodGet,
  168. Path: "/",
  169. Handler: func(w http.ResponseWriter, r *http.Request) {},
  170. }},
  171. },
  172. {
  173. priority: true,
  174. jwt: jwtSetting{
  175. enabled: true,
  176. },
  177. signature: signatureSetting{
  178. enabled: true,
  179. SignatureConf: SignatureConf{
  180. Strict: true,
  181. PrivateKeys: []PrivateKeyConf{
  182. {
  183. Fingerprint: "a",
  184. KeyFile: priKeyfile,
  185. },
  186. },
  187. },
  188. },
  189. routes: []Route{{
  190. Method: http.MethodGet,
  191. Path: "/",
  192. Handler: func(w http.ResponseWriter, r *http.Request) {},
  193. }},
  194. },
  195. }
  196. for _, yaml := range yamls {
  197. yaml := yaml
  198. for _, route := range routes {
  199. route := route
  200. t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
  201. var cnf RestConf
  202. assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
  203. ng := newEngine(cnf)
  204. ng.addRoutes(route)
  205. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  206. return func(w http.ResponseWriter, r *http.Request) {
  207. next.ServeHTTP(w, r)
  208. }
  209. })
  210. assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
  211. }))
  212. timeout := time.Second * 3
  213. if route.timeout > timeout {
  214. timeout = route.timeout
  215. }
  216. assert.Equal(t, timeout, ng.timeout)
  217. })
  218. }
  219. }
  220. }
  221. func TestEngine_checkedTimeout(t *testing.T) {
  222. tests := []struct {
  223. name string
  224. timeout time.Duration
  225. expect time.Duration
  226. }{
  227. {
  228. name: "not set",
  229. expect: time.Second,
  230. },
  231. {
  232. name: "less",
  233. timeout: time.Millisecond * 500,
  234. expect: time.Millisecond * 500,
  235. },
  236. {
  237. name: "equal",
  238. timeout: time.Second,
  239. expect: time.Second,
  240. },
  241. {
  242. name: "more",
  243. timeout: time.Millisecond * 1500,
  244. expect: time.Millisecond * 1500,
  245. },
  246. }
  247. ng := newEngine(RestConf{
  248. Timeout: 1000,
  249. })
  250. for _, test := range tests {
  251. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  252. }
  253. }
  254. func TestEngine_checkedMaxBytes(t *testing.T) {
  255. tests := []struct {
  256. name string
  257. maxBytes int64
  258. expect int64
  259. }{
  260. {
  261. name: "not set",
  262. expect: 1000,
  263. },
  264. {
  265. name: "less",
  266. maxBytes: 500,
  267. expect: 500,
  268. },
  269. {
  270. name: "equal",
  271. maxBytes: 1000,
  272. expect: 1000,
  273. },
  274. {
  275. name: "more",
  276. maxBytes: 1500,
  277. expect: 1500,
  278. },
  279. }
  280. ng := newEngine(RestConf{
  281. MaxBytes: 1000,
  282. })
  283. for _, test := range tests {
  284. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  285. }
  286. }
  287. func TestEngine_notFoundHandler(t *testing.T) {
  288. logx.Disable()
  289. ng := newEngine(RestConf{})
  290. ts := httptest.NewServer(ng.notFoundHandler(nil))
  291. defer ts.Close()
  292. client := ts.Client()
  293. err := func(_ context.Context) error {
  294. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  295. assert.Nil(t, err)
  296. res, err := client.Do(req)
  297. assert.Nil(t, err)
  298. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  299. return res.Body.Close()
  300. }(context.Background())
  301. assert.Nil(t, err)
  302. }
  303. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  304. logx.Disable()
  305. ng := newEngine(RestConf{})
  306. var called int32
  307. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  308. atomic.AddInt32(&called, 1)
  309. })))
  310. defer ts.Close()
  311. client := ts.Client()
  312. err := func(_ context.Context) error {
  313. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  314. assert.Nil(t, err)
  315. res, err := client.Do(req)
  316. assert.Nil(t, err)
  317. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  318. return res.Body.Close()
  319. }(context.Background())
  320. assert.Nil(t, err)
  321. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  322. }
  323. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  324. logx.Disable()
  325. ng := newEngine(RestConf{})
  326. var called int32
  327. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  328. atomic.AddInt32(&called, 1)
  329. w.WriteHeader(http.StatusExpectationFailed)
  330. })))
  331. defer ts.Close()
  332. client := ts.Client()
  333. err := func(_ context.Context) error {
  334. req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
  335. assert.Nil(t, err)
  336. res, err := client.Do(req)
  337. assert.Nil(t, err)
  338. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  339. return res.Body.Close()
  340. }(context.Background())
  341. assert.Nil(t, err)
  342. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  343. }
  344. func TestEngine_withTimeout(t *testing.T) {
  345. logx.Disable()
  346. tests := []struct {
  347. name string
  348. timeout int64
  349. }{
  350. {
  351. name: "not set",
  352. },
  353. {
  354. name: "set",
  355. timeout: 1000,
  356. },
  357. }
  358. for _, test := range tests {
  359. test := test
  360. t.Run(test.name, func(t *testing.T) {
  361. ng := newEngine(RestConf{Timeout: test.timeout})
  362. svr := &http.Server{}
  363. ng.withTimeout()(svr)
  364. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
  365. assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
  366. assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
  367. assert.Equal(t, time.Duration(0), svr.IdleTimeout)
  368. })
  369. }
  370. }
  371. type mockedRouter struct {
  372. }
  373. func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
  374. }
  375. func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
  376. return errors.New("foo")
  377. }
  378. func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
  379. }
  380. func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
  381. }