server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "strings"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/wuntsong-org/go-zero-plus/core/conf"
  15. "github.com/wuntsong-org/go-zero-plus/core/logx/logtest"
  16. "github.com/wuntsong-org/go-zero-plus/rest/chain"
  17. "github.com/wuntsong-org/go-zero-plus/rest/httpx"
  18. "github.com/wuntsong-org/go-zero-plus/rest/internal/cors"
  19. "github.com/wuntsong-org/go-zero-plus/rest/router"
  20. )
  21. func TestNewServer(t *testing.T) {
  22. logtest.Discard(t)
  23. const configYaml = `
  24. Name: foo
  25. Host: localhost
  26. Port: 0
  27. `
  28. var cnf RestConf
  29. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  30. tests := []struct {
  31. c RestConf
  32. opts []RunOption
  33. fail bool
  34. }{
  35. {
  36. c: RestConf{},
  37. opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
  38. },
  39. {
  40. c: cnf,
  41. opts: []RunOption{WithRouter(mockedRouter{})},
  42. },
  43. {
  44. c: cnf,
  45. opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
  46. },
  47. {
  48. c: cnf,
  49. opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
  50. },
  51. {
  52. c: cnf,
  53. opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
  54. },
  55. {
  56. c: cnf,
  57. opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
  58. },
  59. }
  60. for _, test := range tests {
  61. var svr *Server
  62. var err error
  63. if test.fail {
  64. _, err = NewServer(test.c, test.opts...)
  65. assert.NotNil(t, err)
  66. continue
  67. } else {
  68. svr = MustNewServer(test.c, test.opts...)
  69. }
  70. svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
  71. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  72. next.ServeHTTP(w, r)
  73. })
  74. }))
  75. svr.AddRoute(Route{
  76. Method: http.MethodGet,
  77. Path: "/",
  78. Handler: nil,
  79. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  80. WithJwtTransition("preivous", "thenewone"))
  81. func() {
  82. defer func() {
  83. p := recover()
  84. switch v := p.(type) {
  85. case error:
  86. assert.Equal(t, "foo", v.Error())
  87. default:
  88. t.Fail()
  89. }
  90. }()
  91. svr.Start()
  92. svr.Stop()
  93. }()
  94. func() {
  95. defer func() {
  96. p := recover()
  97. switch v := p.(type) {
  98. case error:
  99. assert.Equal(t, "foo", v.Error())
  100. default:
  101. t.Fail()
  102. }
  103. }()
  104. svr.StartWithOpts(func(svr *http.Server) {
  105. svr.RegisterOnShutdown(func() {})
  106. })
  107. svr.Stop()
  108. }()
  109. }
  110. }
  111. func TestWithMaxBytes(t *testing.T) {
  112. const maxBytes = 1000
  113. var fr featuredRoutes
  114. WithMaxBytes(maxBytes)(&fr)
  115. assert.Equal(t, int64(maxBytes), fr.maxBytes)
  116. }
  117. func TestWithMiddleware(t *testing.T) {
  118. m := make(map[string]string)
  119. rt := router.NewRouter()
  120. handler := func(w http.ResponseWriter, r *http.Request) {
  121. var v struct {
  122. Nickname string `form:"nickname"`
  123. Zipcode int64 `form:"zipcode"`
  124. }
  125. err := httpx.Parse(r, &v)
  126. assert.Nil(t, err)
  127. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  128. assert.Nil(t, err)
  129. }
  130. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  131. return func(w http.ResponseWriter, r *http.Request) {
  132. var v struct {
  133. Name string `path:"name"`
  134. Year string `path:"year"`
  135. }
  136. assert.Nil(t, httpx.ParsePath(r, &v))
  137. m[v.Name] = v.Year
  138. next.ServeHTTP(w, r)
  139. }
  140. }, Route{
  141. Method: http.MethodGet,
  142. Path: "/first/:name/:year",
  143. Handler: handler,
  144. }, Route{
  145. Method: http.MethodGet,
  146. Path: "/second/:name/:year",
  147. Handler: handler,
  148. })
  149. urls := []string{
  150. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  151. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  152. }
  153. for _, route := range rs {
  154. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  155. }
  156. for _, url := range urls {
  157. r, err := http.NewRequest(http.MethodGet, url, nil)
  158. assert.Nil(t, err)
  159. rr := httptest.NewRecorder()
  160. rt.ServeHTTP(rr, r)
  161. assert.Equal(t, "whatever:200000", rr.Body.String())
  162. }
  163. assert.EqualValues(t, map[string]string{
  164. "kevin": "2017",
  165. "wan": "2020",
  166. }, m)
  167. }
  168. func TestMultiMiddlewares(t *testing.T) {
  169. m := make(map[string]string)
  170. rt := router.NewRouter()
  171. handler := func(w http.ResponseWriter, r *http.Request) {
  172. var v struct {
  173. Nickname string `form:"nickname"`
  174. Zipcode int64 `form:"zipcode"`
  175. }
  176. err := httpx.Parse(r, &v)
  177. assert.Nil(t, err)
  178. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  179. assert.Nil(t, err)
  180. }
  181. rs := WithMiddlewares([]Middleware{
  182. func(next http.HandlerFunc) http.HandlerFunc {
  183. return func(w http.ResponseWriter, r *http.Request) {
  184. var v struct {
  185. Name string `path:"name"`
  186. Year string `path:"year"`
  187. }
  188. assert.Nil(t, httpx.ParsePath(r, &v))
  189. m[v.Name] = v.Year
  190. next.ServeHTTP(w, r)
  191. }
  192. },
  193. func(next http.HandlerFunc) http.HandlerFunc {
  194. return func(w http.ResponseWriter, r *http.Request) {
  195. var v struct {
  196. Name string `form:"nickname"`
  197. Zipcode string `form:"zipcode"`
  198. }
  199. assert.Nil(t, httpx.ParseForm(r, &v))
  200. assert.NotEmpty(t, m)
  201. m[v.Name] = v.Zipcode + v.Zipcode
  202. next.ServeHTTP(w, r)
  203. }
  204. },
  205. ToMiddleware(func(next http.Handler) http.Handler {
  206. return next
  207. }),
  208. }, Route{
  209. Method: http.MethodGet,
  210. Path: "/first/:name/:year",
  211. Handler: handler,
  212. }, Route{
  213. Method: http.MethodGet,
  214. Path: "/second/:name/:year",
  215. Handler: handler,
  216. })
  217. urls := []string{
  218. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  219. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  220. }
  221. for _, route := range rs {
  222. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  223. }
  224. for _, url := range urls {
  225. r, err := http.NewRequest(http.MethodGet, url, nil)
  226. assert.Nil(t, err)
  227. rr := httptest.NewRecorder()
  228. rt.ServeHTTP(rr, r)
  229. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  230. }
  231. assert.EqualValues(t, map[string]string{
  232. "kevin": "2017",
  233. "wan": "2020",
  234. "whatever": "200000200000",
  235. }, m)
  236. }
  237. func TestWithPrefix(t *testing.T) {
  238. fr := featuredRoutes{
  239. routes: []Route{
  240. {
  241. Path: "/hello",
  242. },
  243. {
  244. Path: "/world",
  245. },
  246. },
  247. }
  248. WithPrefix("/api")(&fr)
  249. vals := make([]string, 0, len(fr.routes))
  250. for _, r := range fr.routes {
  251. vals = append(vals, r.Path)
  252. }
  253. assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
  254. }
  255. func TestWithPriority(t *testing.T) {
  256. var fr featuredRoutes
  257. WithPriority()(&fr)
  258. assert.True(t, fr.priority)
  259. }
  260. func TestWithTimeout(t *testing.T) {
  261. var fr featuredRoutes
  262. WithTimeout(time.Hour)(&fr)
  263. assert.Equal(t, time.Hour, fr.timeout)
  264. }
  265. func TestWithTLSConfig(t *testing.T) {
  266. const configYaml = `
  267. Name: foo
  268. Port: 54321
  269. `
  270. var cnf RestConf
  271. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  272. testConfig := &tls.Config{
  273. CipherSuites: []uint16{
  274. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  275. },
  276. }
  277. testCases := []struct {
  278. c RestConf
  279. opts []RunOption
  280. res *tls.Config
  281. }{
  282. {
  283. c: cnf,
  284. opts: []RunOption{WithTLSConfig(testConfig)},
  285. res: testConfig,
  286. },
  287. {
  288. c: cnf,
  289. opts: []RunOption{WithUnsignedCallback(nil)},
  290. res: nil,
  291. },
  292. }
  293. for _, testCase := range testCases {
  294. svr, err := NewServer(testCase.c, testCase.opts...)
  295. assert.Nil(t, err)
  296. assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
  297. }
  298. }
  299. func TestWithCors(t *testing.T) {
  300. const configYaml = `
  301. Name: foo
  302. Port: 54321
  303. `
  304. var cnf RestConf
  305. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  306. rt := router.NewRouter()
  307. svr, err := NewServer(cnf, WithRouter(rt))
  308. assert.Nil(t, err)
  309. defer svr.Stop()
  310. opt := WithCors("local")
  311. opt(svr)
  312. }
  313. func TestWithCustomCors(t *testing.T) {
  314. const configYaml = `
  315. Name: foo
  316. Port: 54321
  317. `
  318. var cnf RestConf
  319. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  320. rt := router.NewRouter()
  321. svr, err := NewServer(cnf, WithRouter(rt))
  322. assert.Nil(t, err)
  323. opt := WithCustomCors(func(header http.Header) {
  324. header.Set("foo", "bar")
  325. }, func(w http.ResponseWriter) {
  326. w.WriteHeader(http.StatusOK)
  327. }, "local")
  328. opt(svr)
  329. }
  330. func TestServer_PrintRoutes(t *testing.T) {
  331. const (
  332. configYaml = `
  333. Name: foo
  334. Port: 54321
  335. `
  336. expect = `Routes:
  337. GET /bar
  338. GET /foo
  339. GET /foo/:bar
  340. GET /foo/:bar/baz
  341. `
  342. )
  343. var cnf RestConf
  344. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  345. svr, err := NewServer(cnf)
  346. assert.Nil(t, err)
  347. svr.AddRoutes([]Route{
  348. {
  349. Method: http.MethodGet,
  350. Path: "/foo",
  351. Handler: http.NotFound,
  352. },
  353. {
  354. Method: http.MethodGet,
  355. Path: "/bar",
  356. Handler: http.NotFound,
  357. },
  358. {
  359. Method: http.MethodGet,
  360. Path: "/foo/:bar",
  361. Handler: http.NotFound,
  362. },
  363. {
  364. Method: http.MethodGet,
  365. Path: "/foo/:bar/baz",
  366. Handler: http.NotFound,
  367. },
  368. })
  369. old := os.Stdout
  370. r, w, err := os.Pipe()
  371. assert.Nil(t, err)
  372. os.Stdout = w
  373. defer func() {
  374. os.Stdout = old
  375. }()
  376. svr.PrintRoutes()
  377. ch := make(chan string)
  378. go func() {
  379. var buf strings.Builder
  380. io.Copy(&buf, r)
  381. ch <- buf.String()
  382. }()
  383. w.Close()
  384. out := <-ch
  385. assert.Equal(t, expect, out)
  386. }
  387. func TestServer_Routes(t *testing.T) {
  388. const (
  389. configYaml = `
  390. Name: foo
  391. Port: 54321
  392. `
  393. expect = `GET /foo GET /bar GET /foo/:bar GET /foo/:bar/baz`
  394. )
  395. var cnf RestConf
  396. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  397. svr, err := NewServer(cnf)
  398. assert.Nil(t, err)
  399. svr.AddRoutes([]Route{
  400. {
  401. Method: http.MethodGet,
  402. Path: "/foo",
  403. Handler: http.NotFound,
  404. },
  405. {
  406. Method: http.MethodGet,
  407. Path: "/bar",
  408. Handler: http.NotFound,
  409. },
  410. {
  411. Method: http.MethodGet,
  412. Path: "/foo/:bar",
  413. Handler: http.NotFound,
  414. },
  415. {
  416. Method: http.MethodGet,
  417. Path: "/foo/:bar/baz",
  418. Handler: http.NotFound,
  419. },
  420. })
  421. routes := svr.Routes()
  422. var buf strings.Builder
  423. for i := 0; i < len(routes); i++ {
  424. buf.WriteString(routes[i].Method)
  425. buf.WriteString(" ")
  426. buf.WriteString(routes[i].Path)
  427. buf.WriteString(" ")
  428. }
  429. assert.Equal(t, expect, strings.Trim(buf.String(), " "))
  430. }
  431. func TestHandleError(t *testing.T) {
  432. assert.NotPanics(t, func() {
  433. handleError(nil)
  434. handleError(http.ErrServerClosed)
  435. })
  436. }
  437. func TestValidateSecret(t *testing.T) {
  438. assert.Panics(t, func() {
  439. validateSecret("short")
  440. })
  441. }
  442. func TestServer_WithChain(t *testing.T) {
  443. var called int32
  444. middleware1 := func() func(http.Handler) http.Handler {
  445. return func(next http.Handler) http.Handler {
  446. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  447. atomic.AddInt32(&called, 1)
  448. next.ServeHTTP(w, r)
  449. atomic.AddInt32(&called, 1)
  450. })
  451. }
  452. }
  453. middleware2 := func() func(http.Handler) http.Handler {
  454. return func(next http.Handler) http.Handler {
  455. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  456. atomic.AddInt32(&called, 1)
  457. next.ServeHTTP(w, r)
  458. atomic.AddInt32(&called, 1)
  459. })
  460. }
  461. }
  462. server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
  463. server.AddRoutes(
  464. []Route{
  465. {
  466. Method: http.MethodGet,
  467. Path: "/",
  468. Handler: func(_ http.ResponseWriter, _ *http.Request) {
  469. atomic.AddInt32(&called, 1)
  470. },
  471. },
  472. },
  473. )
  474. rt := router.NewRouter()
  475. assert.Nil(t, server.ngin.bindRoutes(rt))
  476. req, err := http.NewRequest(http.MethodGet, "/", http.NoBody)
  477. assert.Nil(t, err)
  478. rt.ServeHTTP(httptest.NewRecorder(), req)
  479. assert.Equal(t, int32(5), atomic.LoadInt32(&called))
  480. }
  481. func TestServer_WithCors(t *testing.T) {
  482. var called int32
  483. middleware := func(next http.Handler) http.Handler {
  484. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  485. atomic.AddInt32(&called, 1)
  486. next.ServeHTTP(w, r)
  487. })
  488. }
  489. r := router.NewRouter()
  490. assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler())))
  491. cr := &corsRouter{
  492. Router: r,
  493. middleware: cors.Middleware(nil, "*"),
  494. }
  495. req := httptest.NewRequest(http.MethodOptions, "/", http.NoBody)
  496. cr.ServeHTTP(httptest.NewRecorder(), req)
  497. assert.Equal(t, int32(0), atomic.LoadInt32(&called))
  498. }
  499. func TestServer_ServeHTTP(t *testing.T) {
  500. const configYaml = `
  501. Name: foo
  502. Port: 54321
  503. `
  504. var cnf RestConf
  505. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  506. svr, err := NewServer(cnf)
  507. assert.Nil(t, err)
  508. svr.AddRoutes([]Route{
  509. {
  510. Method: http.MethodGet,
  511. Path: "/foo",
  512. Handler: func(writer http.ResponseWriter, request *http.Request) {
  513. _, _ = writer.Write([]byte("succeed"))
  514. writer.WriteHeader(http.StatusOK)
  515. },
  516. },
  517. {
  518. Method: http.MethodGet,
  519. Path: "/bar",
  520. Handler: func(writer http.ResponseWriter, request *http.Request) {
  521. _, _ = writer.Write([]byte("succeed"))
  522. writer.WriteHeader(http.StatusOK)
  523. },
  524. },
  525. {
  526. Method: http.MethodGet,
  527. Path: "/user/:name",
  528. Handler: func(writer http.ResponseWriter, request *http.Request) {
  529. var userInfo struct {
  530. Name string `path:"name"`
  531. }
  532. err := httpx.Parse(request, &userInfo)
  533. if err != nil {
  534. _, _ = writer.Write([]byte("failed"))
  535. writer.WriteHeader(http.StatusBadRequest)
  536. return
  537. }
  538. _, _ = writer.Write([]byte("succeed"))
  539. writer.WriteHeader(http.StatusOK)
  540. },
  541. },
  542. })
  543. testCase := []struct {
  544. name string
  545. path string
  546. code int
  547. }{
  548. {
  549. name: "URI : /foo",
  550. path: "/foo",
  551. code: http.StatusOK,
  552. },
  553. {
  554. name: "URI : /bar",
  555. path: "/bar",
  556. code: http.StatusOK,
  557. },
  558. {
  559. name: "URI : undefined path",
  560. path: "/test",
  561. code: http.StatusNotFound,
  562. },
  563. {
  564. name: "URI : /user/:name",
  565. path: "/user/abc",
  566. code: http.StatusOK,
  567. },
  568. }
  569. for _, test := range testCase {
  570. t.Run(test.name, func(t *testing.T) {
  571. w := httptest.NewRecorder()
  572. req, _ := http.NewRequest("GET", test.path, nil)
  573. svr.ServeHTTP(w, req)
  574. assert.Equal(t, test.code, w.Code)
  575. })
  576. }
  577. }