server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "strings"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/zeromicro/go-zero/core/conf"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. "github.com/zeromicro/go-zero/rest/chain"
  17. "github.com/zeromicro/go-zero/rest/httpx"
  18. "github.com/zeromicro/go-zero/rest/internal/cors"
  19. "github.com/zeromicro/go-zero/rest/router"
  20. )
  21. func TestNewServer(t *testing.T) {
  22. writer := logx.Reset()
  23. defer logx.SetWriter(writer)
  24. logx.SetWriter(logx.NewWriter(io.Discard))
  25. const configYaml = `
  26. Name: foo
  27. Host: localhost
  28. Port: 0
  29. `
  30. var cnf RestConf
  31. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  32. tests := []struct {
  33. c RestConf
  34. opts []RunOption
  35. fail bool
  36. }{
  37. {
  38. c: RestConf{},
  39. opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
  40. },
  41. {
  42. c: cnf,
  43. opts: []RunOption{WithRouter(mockedRouter{})},
  44. },
  45. {
  46. c: cnf,
  47. opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
  48. },
  49. {
  50. c: cnf,
  51. opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
  52. },
  53. {
  54. c: cnf,
  55. opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
  56. },
  57. {
  58. c: cnf,
  59. opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
  60. },
  61. }
  62. for _, test := range tests {
  63. var svr *Server
  64. var err error
  65. if test.fail {
  66. _, err = NewServer(test.c, test.opts...)
  67. assert.NotNil(t, err)
  68. continue
  69. } else {
  70. svr = MustNewServer(test.c, test.opts...)
  71. }
  72. svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
  73. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  74. next.ServeHTTP(w, r)
  75. })
  76. }))
  77. svr.AddRoute(Route{
  78. Method: http.MethodGet,
  79. Path: "/",
  80. Handler: nil,
  81. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  82. WithJwtTransition("preivous", "thenewone"))
  83. func() {
  84. defer func() {
  85. p := recover()
  86. switch v := p.(type) {
  87. case error:
  88. assert.Equal(t, "foo", v.Error())
  89. default:
  90. t.Fail()
  91. }
  92. }()
  93. svr.Start()
  94. svr.Stop()
  95. }()
  96. func() {
  97. defer func() {
  98. p := recover()
  99. switch v := p.(type) {
  100. case error:
  101. assert.Equal(t, "foo", v.Error())
  102. default:
  103. t.Fail()
  104. }
  105. }()
  106. svr.StartWithOpts(func(svr *http.Server) {
  107. svr.RegisterOnShutdown(func() {})
  108. })
  109. svr.Stop()
  110. }()
  111. }
  112. }
  113. func TestWithMaxBytes(t *testing.T) {
  114. const maxBytes = 1000
  115. var fr featuredRoutes
  116. WithMaxBytes(maxBytes)(&fr)
  117. assert.Equal(t, int64(maxBytes), fr.maxBytes)
  118. }
  119. func TestWithMiddleware(t *testing.T) {
  120. m := make(map[string]string)
  121. rt := router.NewRouter()
  122. handler := func(w http.ResponseWriter, r *http.Request) {
  123. var v struct {
  124. Nickname string `form:"nickname"`
  125. Zipcode int64 `form:"zipcode"`
  126. }
  127. err := httpx.Parse(r, &v)
  128. assert.Nil(t, err)
  129. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  130. assert.Nil(t, err)
  131. }
  132. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  133. return func(w http.ResponseWriter, r *http.Request) {
  134. var v struct {
  135. Name string `path:"name"`
  136. Year string `path:"year"`
  137. }
  138. assert.Nil(t, httpx.ParsePath(r, &v))
  139. m[v.Name] = v.Year
  140. next.ServeHTTP(w, r)
  141. }
  142. }, Route{
  143. Method: http.MethodGet,
  144. Path: "/first/:name/:year",
  145. Handler: handler,
  146. }, Route{
  147. Method: http.MethodGet,
  148. Path: "/second/:name/:year",
  149. Handler: handler,
  150. })
  151. urls := []string{
  152. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  153. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  154. }
  155. for _, route := range rs {
  156. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  157. }
  158. for _, url := range urls {
  159. r, err := http.NewRequest(http.MethodGet, url, nil)
  160. assert.Nil(t, err)
  161. rr := httptest.NewRecorder()
  162. rt.ServeHTTP(rr, r)
  163. assert.Equal(t, "whatever:200000", rr.Body.String())
  164. }
  165. assert.EqualValues(t, map[string]string{
  166. "kevin": "2017",
  167. "wan": "2020",
  168. }, m)
  169. }
  170. func TestMultiMiddlewares(t *testing.T) {
  171. m := make(map[string]string)
  172. rt := router.NewRouter()
  173. handler := func(w http.ResponseWriter, r *http.Request) {
  174. var v struct {
  175. Nickname string `form:"nickname"`
  176. Zipcode int64 `form:"zipcode"`
  177. }
  178. err := httpx.Parse(r, &v)
  179. assert.Nil(t, err)
  180. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  181. assert.Nil(t, err)
  182. }
  183. rs := WithMiddlewares([]Middleware{
  184. func(next http.HandlerFunc) http.HandlerFunc {
  185. return func(w http.ResponseWriter, r *http.Request) {
  186. var v struct {
  187. Name string `path:"name"`
  188. Year string `path:"year"`
  189. }
  190. assert.Nil(t, httpx.ParsePath(r, &v))
  191. m[v.Name] = v.Year
  192. next.ServeHTTP(w, r)
  193. }
  194. },
  195. func(next http.HandlerFunc) http.HandlerFunc {
  196. return func(w http.ResponseWriter, r *http.Request) {
  197. var v struct {
  198. Name string `form:"nickname"`
  199. Zipcode string `form:"zipcode"`
  200. }
  201. assert.Nil(t, httpx.ParseForm(r, &v))
  202. assert.NotEmpty(t, m)
  203. m[v.Name] = v.Zipcode + v.Zipcode
  204. next.ServeHTTP(w, r)
  205. }
  206. },
  207. ToMiddleware(func(next http.Handler) http.Handler {
  208. return next
  209. }),
  210. }, Route{
  211. Method: http.MethodGet,
  212. Path: "/first/:name/:year",
  213. Handler: handler,
  214. }, Route{
  215. Method: http.MethodGet,
  216. Path: "/second/:name/:year",
  217. Handler: handler,
  218. })
  219. urls := []string{
  220. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  221. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  222. }
  223. for _, route := range rs {
  224. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  225. }
  226. for _, url := range urls {
  227. r, err := http.NewRequest(http.MethodGet, url, nil)
  228. assert.Nil(t, err)
  229. rr := httptest.NewRecorder()
  230. rt.ServeHTTP(rr, r)
  231. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  232. }
  233. assert.EqualValues(t, map[string]string{
  234. "kevin": "2017",
  235. "wan": "2020",
  236. "whatever": "200000200000",
  237. }, m)
  238. }
  239. func TestWithPrefix(t *testing.T) {
  240. fr := featuredRoutes{
  241. routes: []Route{
  242. {
  243. Path: "/hello",
  244. },
  245. {
  246. Path: "/world",
  247. },
  248. },
  249. }
  250. WithPrefix("/api")(&fr)
  251. vals := make([]string, 0, len(fr.routes))
  252. for _, r := range fr.routes {
  253. vals = append(vals, r.Path)
  254. }
  255. assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
  256. }
  257. func TestWithPriority(t *testing.T) {
  258. var fr featuredRoutes
  259. WithPriority()(&fr)
  260. assert.True(t, fr.priority)
  261. }
  262. func TestWithTimeout(t *testing.T) {
  263. var fr featuredRoutes
  264. WithTimeout(time.Hour)(&fr)
  265. assert.Equal(t, time.Hour, fr.timeout)
  266. }
  267. func TestWithTLSConfig(t *testing.T) {
  268. const configYaml = `
  269. Name: foo
  270. Port: 54321
  271. `
  272. var cnf RestConf
  273. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  274. testConfig := &tls.Config{
  275. CipherSuites: []uint16{
  276. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  277. },
  278. }
  279. testCases := []struct {
  280. c RestConf
  281. opts []RunOption
  282. res *tls.Config
  283. }{
  284. {
  285. c: cnf,
  286. opts: []RunOption{WithTLSConfig(testConfig)},
  287. res: testConfig,
  288. },
  289. {
  290. c: cnf,
  291. opts: []RunOption{WithUnsignedCallback(nil)},
  292. res: nil,
  293. },
  294. }
  295. for _, testCase := range testCases {
  296. svr, err := NewServer(testCase.c, testCase.opts...)
  297. assert.Nil(t, err)
  298. assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
  299. }
  300. }
  301. func TestWithCors(t *testing.T) {
  302. const configYaml = `
  303. Name: foo
  304. Port: 54321
  305. `
  306. var cnf RestConf
  307. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  308. rt := router.NewRouter()
  309. svr, err := NewServer(cnf, WithRouter(rt))
  310. assert.Nil(t, err)
  311. defer svr.Stop()
  312. opt := WithCors("local")
  313. opt(svr)
  314. }
  315. func TestWithCustomCors(t *testing.T) {
  316. const configYaml = `
  317. Name: foo
  318. Port: 54321
  319. `
  320. var cnf RestConf
  321. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  322. rt := router.NewRouter()
  323. svr, err := NewServer(cnf, WithRouter(rt))
  324. assert.Nil(t, err)
  325. opt := WithCustomCors(func(header http.Header) {
  326. header.Set("foo", "bar")
  327. }, func(w http.ResponseWriter) {
  328. w.WriteHeader(http.StatusOK)
  329. }, "local")
  330. opt(svr)
  331. }
  332. func TestServer_PrintRoutes(t *testing.T) {
  333. const (
  334. configYaml = `
  335. Name: foo
  336. Port: 54321
  337. `
  338. expect = `Routes:
  339. GET /bar
  340. GET /foo
  341. GET /foo/:bar
  342. GET /foo/:bar/baz
  343. `
  344. )
  345. var cnf RestConf
  346. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  347. svr, err := NewServer(cnf)
  348. assert.Nil(t, err)
  349. svr.AddRoutes([]Route{
  350. {
  351. Method: http.MethodGet,
  352. Path: "/foo",
  353. Handler: http.NotFound,
  354. },
  355. {
  356. Method: http.MethodGet,
  357. Path: "/bar",
  358. Handler: http.NotFound,
  359. },
  360. {
  361. Method: http.MethodGet,
  362. Path: "/foo/:bar",
  363. Handler: http.NotFound,
  364. },
  365. {
  366. Method: http.MethodGet,
  367. Path: "/foo/:bar/baz",
  368. Handler: http.NotFound,
  369. },
  370. })
  371. old := os.Stdout
  372. r, w, err := os.Pipe()
  373. assert.Nil(t, err)
  374. os.Stdout = w
  375. defer func() {
  376. os.Stdout = old
  377. }()
  378. svr.PrintRoutes()
  379. ch := make(chan string)
  380. go func() {
  381. var buf strings.Builder
  382. io.Copy(&buf, r)
  383. ch <- buf.String()
  384. }()
  385. w.Close()
  386. out := <-ch
  387. assert.Equal(t, expect, out)
  388. }
  389. func TestServer_Routes(t *testing.T) {
  390. const (
  391. configYaml = `
  392. Name: foo
  393. Port: 54321
  394. `
  395. expect = `GET /foo GET /bar GET /foo/:bar GET /foo/:bar/baz`
  396. )
  397. var cnf RestConf
  398. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  399. svr, err := NewServer(cnf)
  400. assert.Nil(t, err)
  401. svr.AddRoutes([]Route{
  402. {
  403. Method: http.MethodGet,
  404. Path: "/foo",
  405. Handler: http.NotFound,
  406. },
  407. {
  408. Method: http.MethodGet,
  409. Path: "/bar",
  410. Handler: http.NotFound,
  411. },
  412. {
  413. Method: http.MethodGet,
  414. Path: "/foo/:bar",
  415. Handler: http.NotFound,
  416. },
  417. {
  418. Method: http.MethodGet,
  419. Path: "/foo/:bar/baz",
  420. Handler: http.NotFound,
  421. },
  422. })
  423. routes := svr.Routes()
  424. var buf strings.Builder
  425. for i := 0; i < len(routes); i++ {
  426. buf.WriteString(routes[i].Method)
  427. buf.WriteString(" ")
  428. buf.WriteString(routes[i].Path)
  429. buf.WriteString(" ")
  430. }
  431. assert.Equal(t, expect, strings.Trim(buf.String(), " "))
  432. }
  433. func TestHandleError(t *testing.T) {
  434. assert.NotPanics(t, func() {
  435. handleError(nil)
  436. handleError(http.ErrServerClosed)
  437. })
  438. }
  439. func TestValidateSecret(t *testing.T) {
  440. assert.Panics(t, func() {
  441. validateSecret("short")
  442. })
  443. }
  444. func TestServer_WithChain(t *testing.T) {
  445. var called int32
  446. middleware1 := func() func(http.Handler) http.Handler {
  447. return func(next http.Handler) http.Handler {
  448. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  449. atomic.AddInt32(&called, 1)
  450. next.ServeHTTP(w, r)
  451. atomic.AddInt32(&called, 1)
  452. })
  453. }
  454. }
  455. middleware2 := func() func(http.Handler) http.Handler {
  456. return func(next http.Handler) http.Handler {
  457. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  458. atomic.AddInt32(&called, 1)
  459. next.ServeHTTP(w, r)
  460. atomic.AddInt32(&called, 1)
  461. })
  462. }
  463. }
  464. server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
  465. server.AddRoutes(
  466. []Route{
  467. {
  468. Method: http.MethodGet,
  469. Path: "/",
  470. Handler: func(_ http.ResponseWriter, _ *http.Request) {
  471. atomic.AddInt32(&called, 1)
  472. },
  473. },
  474. },
  475. )
  476. rt := router.NewRouter()
  477. assert.Nil(t, server.ngin.bindRoutes(rt))
  478. req, err := http.NewRequest(http.MethodGet, "/", http.NoBody)
  479. assert.Nil(t, err)
  480. rt.ServeHTTP(httptest.NewRecorder(), req)
  481. assert.Equal(t, int32(5), atomic.LoadInt32(&called))
  482. }
  483. func TestServer_WithCors(t *testing.T) {
  484. var called int32
  485. middleware := func(next http.Handler) http.Handler {
  486. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  487. atomic.AddInt32(&called, 1)
  488. next.ServeHTTP(w, r)
  489. })
  490. }
  491. r := router.NewRouter()
  492. assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler())))
  493. cr := &corsRouter{
  494. Router: r,
  495. middleware: cors.Middleware(nil, "*"),
  496. }
  497. req := httptest.NewRequest(http.MethodOptions, "/", http.NoBody)
  498. cr.ServeHTTP(httptest.NewRecorder(), req)
  499. assert.Equal(t, int32(0), atomic.LoadInt32(&called))
  500. }
  501. func TestServer_ServeHTTP(t *testing.T) {
  502. const configYaml = `
  503. Name: foo
  504. Port: 54321
  505. `
  506. var cnf RestConf
  507. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  508. svr, err := NewServer(cnf)
  509. assert.Nil(t, err)
  510. svr.AddRoutes([]Route{
  511. {
  512. Method: http.MethodGet,
  513. Path: "/foo",
  514. Handler: func(writer http.ResponseWriter, request *http.Request) {
  515. _, _ = writer.Write([]byte("succeed"))
  516. writer.WriteHeader(http.StatusOK)
  517. },
  518. },
  519. {
  520. Method: http.MethodGet,
  521. Path: "/bar",
  522. Handler: func(writer http.ResponseWriter, request *http.Request) {
  523. _, _ = writer.Write([]byte("succeed"))
  524. writer.WriteHeader(http.StatusOK)
  525. },
  526. },
  527. {
  528. Method: http.MethodGet,
  529. Path: "/user/:name",
  530. Handler: func(writer http.ResponseWriter, request *http.Request) {
  531. var userInfo struct {
  532. Name string `path:"name"`
  533. }
  534. err := httpx.Parse(request, &userInfo)
  535. if err != nil {
  536. _, _ = writer.Write([]byte("failed"))
  537. writer.WriteHeader(http.StatusBadRequest)
  538. return
  539. }
  540. _, _ = writer.Write([]byte("succeed"))
  541. writer.WriteHeader(http.StatusOK)
  542. },
  543. },
  544. })
  545. testCase := []struct {
  546. name string
  547. path string
  548. code int
  549. }{
  550. {
  551. name: "URI : /foo",
  552. path: "/foo",
  553. code: http.StatusOK,
  554. },
  555. {
  556. name: "URI : /bar",
  557. path: "/bar",
  558. code: http.StatusOK,
  559. },
  560. {
  561. name: "URI : undefined path",
  562. path: "/test",
  563. code: http.StatusNotFound,
  564. },
  565. {
  566. name: "URI : /user/:name",
  567. path: "/user/abc",
  568. code: http.StatusOK,
  569. },
  570. }
  571. for _, test := range testCase {
  572. t.Run(test.name, func(t *testing.T) {
  573. w := httptest.NewRecorder()
  574. req, _ := http.NewRequest("GET", test.path, nil)
  575. svr.ServeHTTP(w, req)
  576. assert.Equal(t, test.code, w.Code)
  577. })
  578. }
  579. }