server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "strings"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/zeromicro/go-zero/core/conf"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. "github.com/zeromicro/go-zero/rest/chain"
  17. "github.com/zeromicro/go-zero/rest/httpx"
  18. "github.com/zeromicro/go-zero/rest/internal/cors"
  19. "github.com/zeromicro/go-zero/rest/router"
  20. )
  21. func TestNewServer(t *testing.T) {
  22. writer := logx.Reset()
  23. defer logx.SetWriter(writer)
  24. logx.SetWriter(logx.NewWriter(io.Discard))
  25. const configYaml = `
  26. Name: foo
  27. Port: 54321
  28. `
  29. var cnf RestConf
  30. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  31. tests := []struct {
  32. c RestConf
  33. opts []RunOption
  34. fail bool
  35. }{
  36. {
  37. c: RestConf{},
  38. opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
  39. },
  40. {
  41. c: cnf,
  42. opts: []RunOption{WithRouter(mockedRouter{})},
  43. },
  44. {
  45. c: cnf,
  46. opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
  47. },
  48. {
  49. c: cnf,
  50. opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
  51. },
  52. {
  53. c: cnf,
  54. opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
  55. },
  56. {
  57. c: cnf,
  58. opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
  59. },
  60. }
  61. for _, test := range tests {
  62. var svr *Server
  63. var err error
  64. if test.fail {
  65. _, err = NewServer(test.c, test.opts...)
  66. assert.NotNil(t, err)
  67. continue
  68. } else {
  69. svr = MustNewServer(test.c, test.opts...)
  70. }
  71. svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
  72. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  73. next.ServeHTTP(w, r)
  74. })
  75. }))
  76. svr.AddRoute(Route{
  77. Method: http.MethodGet,
  78. Path: "/",
  79. Handler: nil,
  80. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  81. WithJwtTransition("preivous", "thenewone"))
  82. func() {
  83. defer func() {
  84. p := recover()
  85. switch v := p.(type) {
  86. case error:
  87. assert.Equal(t, "foo", v.Error())
  88. default:
  89. t.Fail()
  90. }
  91. }()
  92. svr.Start()
  93. svr.Stop()
  94. }()
  95. }
  96. }
  97. func TestWithMaxBytes(t *testing.T) {
  98. const maxBytes = 1000
  99. var fr featuredRoutes
  100. WithMaxBytes(maxBytes)(&fr)
  101. assert.Equal(t, int64(maxBytes), fr.maxBytes)
  102. }
  103. func TestWithMiddleware(t *testing.T) {
  104. m := make(map[string]string)
  105. rt := router.NewRouter()
  106. handler := func(w http.ResponseWriter, r *http.Request) {
  107. var v struct {
  108. Nickname string `form:"nickname"`
  109. Zipcode int64 `form:"zipcode"`
  110. }
  111. err := httpx.Parse(r, &v)
  112. assert.Nil(t, err)
  113. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  114. assert.Nil(t, err)
  115. }
  116. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  117. return func(w http.ResponseWriter, r *http.Request) {
  118. var v struct {
  119. Name string `path:"name"`
  120. Year string `path:"year"`
  121. }
  122. assert.Nil(t, httpx.ParsePath(r, &v))
  123. m[v.Name] = v.Year
  124. next.ServeHTTP(w, r)
  125. }
  126. }, Route{
  127. Method: http.MethodGet,
  128. Path: "/first/:name/:year",
  129. Handler: handler,
  130. }, Route{
  131. Method: http.MethodGet,
  132. Path: "/second/:name/:year",
  133. Handler: handler,
  134. })
  135. urls := []string{
  136. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  137. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  138. }
  139. for _, route := range rs {
  140. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  141. }
  142. for _, url := range urls {
  143. r, err := http.NewRequest(http.MethodGet, url, nil)
  144. assert.Nil(t, err)
  145. rr := httptest.NewRecorder()
  146. rt.ServeHTTP(rr, r)
  147. assert.Equal(t, "whatever:200000", rr.Body.String())
  148. }
  149. assert.EqualValues(t, map[string]string{
  150. "kevin": "2017",
  151. "wan": "2020",
  152. }, m)
  153. }
  154. func TestMultiMiddlewares(t *testing.T) {
  155. m := make(map[string]string)
  156. rt := router.NewRouter()
  157. handler := func(w http.ResponseWriter, r *http.Request) {
  158. var v struct {
  159. Nickname string `form:"nickname"`
  160. Zipcode int64 `form:"zipcode"`
  161. }
  162. err := httpx.Parse(r, &v)
  163. assert.Nil(t, err)
  164. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  165. assert.Nil(t, err)
  166. }
  167. rs := WithMiddlewares([]Middleware{
  168. func(next http.HandlerFunc) http.HandlerFunc {
  169. return func(w http.ResponseWriter, r *http.Request) {
  170. var v struct {
  171. Name string `path:"name"`
  172. Year string `path:"year"`
  173. }
  174. assert.Nil(t, httpx.ParsePath(r, &v))
  175. m[v.Name] = v.Year
  176. next.ServeHTTP(w, r)
  177. }
  178. },
  179. func(next http.HandlerFunc) http.HandlerFunc {
  180. return func(w http.ResponseWriter, r *http.Request) {
  181. var v struct {
  182. Name string `form:"nickname"`
  183. Zipcode string `form:"zipcode"`
  184. }
  185. assert.Nil(t, httpx.ParseForm(r, &v))
  186. assert.NotEmpty(t, m)
  187. m[v.Name] = v.Zipcode + v.Zipcode
  188. next.ServeHTTP(w, r)
  189. }
  190. },
  191. ToMiddleware(func(next http.Handler) http.Handler {
  192. return next
  193. }),
  194. }, Route{
  195. Method: http.MethodGet,
  196. Path: "/first/:name/:year",
  197. Handler: handler,
  198. }, Route{
  199. Method: http.MethodGet,
  200. Path: "/second/:name/:year",
  201. Handler: handler,
  202. })
  203. urls := []string{
  204. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  205. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  206. }
  207. for _, route := range rs {
  208. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  209. }
  210. for _, url := range urls {
  211. r, err := http.NewRequest(http.MethodGet, url, nil)
  212. assert.Nil(t, err)
  213. rr := httptest.NewRecorder()
  214. rt.ServeHTTP(rr, r)
  215. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  216. }
  217. assert.EqualValues(t, map[string]string{
  218. "kevin": "2017",
  219. "wan": "2020",
  220. "whatever": "200000200000",
  221. }, m)
  222. }
  223. func TestWithPrefix(t *testing.T) {
  224. fr := featuredRoutes{
  225. routes: []Route{
  226. {
  227. Path: "/hello",
  228. },
  229. {
  230. Path: "/world",
  231. },
  232. },
  233. }
  234. WithPrefix("/api")(&fr)
  235. vals := make([]string, 0, len(fr.routes))
  236. for _, r := range fr.routes {
  237. vals = append(vals, r.Path)
  238. }
  239. assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
  240. }
  241. func TestWithPriority(t *testing.T) {
  242. var fr featuredRoutes
  243. WithPriority()(&fr)
  244. assert.True(t, fr.priority)
  245. }
  246. func TestWithTimeout(t *testing.T) {
  247. var fr featuredRoutes
  248. WithTimeout(time.Hour)(&fr)
  249. assert.Equal(t, time.Hour, fr.timeout)
  250. }
  251. func TestWithTLSConfig(t *testing.T) {
  252. const configYaml = `
  253. Name: foo
  254. Port: 54321
  255. `
  256. var cnf RestConf
  257. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  258. testConfig := &tls.Config{
  259. CipherSuites: []uint16{
  260. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  261. },
  262. }
  263. testCases := []struct {
  264. c RestConf
  265. opts []RunOption
  266. res *tls.Config
  267. }{
  268. {
  269. c: cnf,
  270. opts: []RunOption{WithTLSConfig(testConfig)},
  271. res: testConfig,
  272. },
  273. {
  274. c: cnf,
  275. opts: []RunOption{WithUnsignedCallback(nil)},
  276. res: nil,
  277. },
  278. }
  279. for _, testCase := range testCases {
  280. svr, err := NewServer(testCase.c, testCase.opts...)
  281. assert.Nil(t, err)
  282. assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
  283. }
  284. }
  285. func TestWithCors(t *testing.T) {
  286. const configYaml = `
  287. Name: foo
  288. Port: 54321
  289. `
  290. var cnf RestConf
  291. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  292. rt := router.NewRouter()
  293. svr, err := NewServer(cnf, WithRouter(rt))
  294. assert.Nil(t, err)
  295. defer svr.Stop()
  296. opt := WithCors("local")
  297. opt(svr)
  298. }
  299. func TestWithCustomCors(t *testing.T) {
  300. const configYaml = `
  301. Name: foo
  302. Port: 54321
  303. `
  304. var cnf RestConf
  305. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  306. rt := router.NewRouter()
  307. svr, err := NewServer(cnf, WithRouter(rt))
  308. assert.Nil(t, err)
  309. opt := WithCustomCors(func(header http.Header) {
  310. header.Set("foo", "bar")
  311. }, func(w http.ResponseWriter) {
  312. w.WriteHeader(http.StatusOK)
  313. }, "local")
  314. opt(svr)
  315. }
  316. func TestServer_PrintRoutes(t *testing.T) {
  317. const (
  318. configYaml = `
  319. Name: foo
  320. Port: 54321
  321. `
  322. expect = `Routes:
  323. GET /bar
  324. GET /foo
  325. GET /foo/:bar
  326. GET /foo/:bar/baz
  327. `
  328. )
  329. var cnf RestConf
  330. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  331. svr, err := NewServer(cnf)
  332. assert.Nil(t, err)
  333. svr.AddRoutes([]Route{
  334. {
  335. Method: http.MethodGet,
  336. Path: "/foo",
  337. Handler: http.NotFound,
  338. },
  339. {
  340. Method: http.MethodGet,
  341. Path: "/bar",
  342. Handler: http.NotFound,
  343. },
  344. {
  345. Method: http.MethodGet,
  346. Path: "/foo/:bar",
  347. Handler: http.NotFound,
  348. },
  349. {
  350. Method: http.MethodGet,
  351. Path: "/foo/:bar/baz",
  352. Handler: http.NotFound,
  353. },
  354. })
  355. old := os.Stdout
  356. r, w, err := os.Pipe()
  357. assert.Nil(t, err)
  358. os.Stdout = w
  359. defer func() {
  360. os.Stdout = old
  361. }()
  362. svr.PrintRoutes()
  363. ch := make(chan string)
  364. go func() {
  365. var buf strings.Builder
  366. io.Copy(&buf, r)
  367. ch <- buf.String()
  368. }()
  369. w.Close()
  370. out := <-ch
  371. assert.Equal(t, expect, out)
  372. }
  373. func TestServer_Routes(t *testing.T) {
  374. const (
  375. configYaml = `
  376. Name: foo
  377. Port: 54321
  378. `
  379. expect = `GET /foo GET /bar GET /foo/:bar GET /foo/:bar/baz`
  380. )
  381. var cnf RestConf
  382. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  383. svr, err := NewServer(cnf)
  384. assert.Nil(t, err)
  385. svr.AddRoutes([]Route{
  386. {
  387. Method: http.MethodGet,
  388. Path: "/foo",
  389. Handler: http.NotFound,
  390. },
  391. {
  392. Method: http.MethodGet,
  393. Path: "/bar",
  394. Handler: http.NotFound,
  395. },
  396. {
  397. Method: http.MethodGet,
  398. Path: "/foo/:bar",
  399. Handler: http.NotFound,
  400. },
  401. {
  402. Method: http.MethodGet,
  403. Path: "/foo/:bar/baz",
  404. Handler: http.NotFound,
  405. },
  406. })
  407. routes := svr.Routes()
  408. var buf strings.Builder
  409. for i := 0; i < len(routes); i++ {
  410. buf.WriteString(routes[i].Method)
  411. buf.WriteString(" ")
  412. buf.WriteString(routes[i].Path)
  413. buf.WriteString(" ")
  414. }
  415. assert.Equal(t, expect, strings.Trim(buf.String(), " "))
  416. }
  417. func TestHandleError(t *testing.T) {
  418. assert.NotPanics(t, func() {
  419. handleError(nil)
  420. handleError(http.ErrServerClosed)
  421. })
  422. }
  423. func TestValidateSecret(t *testing.T) {
  424. assert.Panics(t, func() {
  425. validateSecret("short")
  426. })
  427. }
  428. func TestServer_WithChain(t *testing.T) {
  429. var called int32
  430. middleware1 := func() func(http.Handler) http.Handler {
  431. return func(next http.Handler) http.Handler {
  432. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  433. atomic.AddInt32(&called, 1)
  434. next.ServeHTTP(w, r)
  435. atomic.AddInt32(&called, 1)
  436. })
  437. }
  438. }
  439. middleware2 := func() func(http.Handler) http.Handler {
  440. return func(next http.Handler) http.Handler {
  441. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  442. atomic.AddInt32(&called, 1)
  443. next.ServeHTTP(w, r)
  444. atomic.AddInt32(&called, 1)
  445. })
  446. }
  447. }
  448. server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
  449. server.AddRoutes(
  450. []Route{
  451. {
  452. Method: http.MethodGet,
  453. Path: "/",
  454. Handler: func(_ http.ResponseWriter, _ *http.Request) {
  455. atomic.AddInt32(&called, 1)
  456. },
  457. },
  458. },
  459. )
  460. rt := router.NewRouter()
  461. assert.Nil(t, server.ngin.bindRoutes(rt))
  462. req, err := http.NewRequest(http.MethodGet, "/", http.NoBody)
  463. assert.Nil(t, err)
  464. rt.ServeHTTP(httptest.NewRecorder(), req)
  465. assert.Equal(t, int32(5), atomic.LoadInt32(&called))
  466. }
  467. func TestServer_WithCors(t *testing.T) {
  468. var called int32
  469. middleware := func(next http.Handler) http.Handler {
  470. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  471. atomic.AddInt32(&called, 1)
  472. next.ServeHTTP(w, r)
  473. })
  474. }
  475. r := router.NewRouter()
  476. assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler())))
  477. cr := &corsRouter{
  478. Router: r,
  479. middleware: cors.Middleware(nil, "*"),
  480. }
  481. req := httptest.NewRequest(http.MethodOptions, "/", http.NoBody)
  482. cr.ServeHTTP(httptest.NewRecorder(), req)
  483. assert.Equal(t, int32(0), atomic.LoadInt32(&called))
  484. }
  485. func TestServer_ServeHTTP(t *testing.T) {
  486. const configYaml = `
  487. Name: foo
  488. Port: 54321
  489. `
  490. var cnf RestConf
  491. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  492. svr, err := NewServer(cnf)
  493. assert.Nil(t, err)
  494. svr.AddRoutes([]Route{
  495. {
  496. Method: http.MethodGet,
  497. Path: "/foo",
  498. Handler: func(writer http.ResponseWriter, request *http.Request) {
  499. _, _ = writer.Write([]byte("succeed"))
  500. writer.WriteHeader(http.StatusOK)
  501. },
  502. },
  503. {
  504. Method: http.MethodGet,
  505. Path: "/bar",
  506. Handler: func(writer http.ResponseWriter, request *http.Request) {
  507. _, _ = writer.Write([]byte("succeed"))
  508. writer.WriteHeader(http.StatusOK)
  509. },
  510. },
  511. {
  512. Method: http.MethodGet,
  513. Path: "/user/:name",
  514. Handler: func(writer http.ResponseWriter, request *http.Request) {
  515. var userInfo struct {
  516. Name string `path:"name"`
  517. }
  518. err := httpx.Parse(request, &userInfo)
  519. if err != nil {
  520. _, _ = writer.Write([]byte("failed"))
  521. writer.WriteHeader(http.StatusBadRequest)
  522. return
  523. }
  524. _, _ = writer.Write([]byte("succeed"))
  525. writer.WriteHeader(http.StatusOK)
  526. },
  527. },
  528. })
  529. testCase := []struct {
  530. name string
  531. path string
  532. code int
  533. }{
  534. {
  535. name: "URI : /foo",
  536. path: "/foo",
  537. code: http.StatusOK,
  538. },
  539. {
  540. name: "URI : /bar",
  541. path: "/bar",
  542. code: http.StatusOK,
  543. },
  544. {
  545. name: "URI : undefined path",
  546. path: "/test",
  547. code: http.StatusNotFound,
  548. },
  549. {
  550. name: "URI : /user/:name",
  551. path: "/user/abc",
  552. code: http.StatusOK,
  553. },
  554. }
  555. for _, test := range testCase {
  556. t.Run(test.name, func(t *testing.T) {
  557. w := httptest.NewRecorder()
  558. req, _ := http.NewRequest("GET", test.path, nil)
  559. svr.ServeHTTP(w, req)
  560. assert.Equal(t, test.code, w.Code)
  561. })
  562. }
  563. }