server_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "strings"
  11. "sync/atomic"
  12. "testing"
  13. "time"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/zeromicro/go-zero/core/conf"
  16. "github.com/zeromicro/go-zero/core/logx"
  17. "github.com/zeromicro/go-zero/rest/chain"
  18. "github.com/zeromicro/go-zero/rest/httpx"
  19. "github.com/zeromicro/go-zero/rest/internal/cors"
  20. "github.com/zeromicro/go-zero/rest/router"
  21. )
  22. func TestNewServer(t *testing.T) {
  23. writer := logx.Reset()
  24. defer logx.SetWriter(writer)
  25. logx.SetWriter(logx.NewWriter(ioutil.Discard))
  26. const configYaml = `
  27. Name: foo
  28. Port: 54321
  29. `
  30. var cnf RestConf
  31. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  32. tests := []struct {
  33. c RestConf
  34. opts []RunOption
  35. fail bool
  36. }{
  37. {
  38. c: RestConf{},
  39. opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
  40. },
  41. {
  42. c: cnf,
  43. opts: []RunOption{WithRouter(mockedRouter{})},
  44. },
  45. {
  46. c: cnf,
  47. opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
  48. },
  49. {
  50. c: cnf,
  51. opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
  52. },
  53. {
  54. c: cnf,
  55. opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
  56. },
  57. {
  58. c: cnf,
  59. opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
  60. },
  61. }
  62. for _, test := range tests {
  63. var svr *Server
  64. var err error
  65. if test.fail {
  66. _, err = NewServer(test.c, test.opts...)
  67. assert.NotNil(t, err)
  68. continue
  69. } else {
  70. svr = MustNewServer(test.c, test.opts...)
  71. }
  72. svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
  73. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  74. next.ServeHTTP(w, r)
  75. })
  76. }))
  77. svr.AddRoute(Route{
  78. Method: http.MethodGet,
  79. Path: "/",
  80. Handler: nil,
  81. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  82. WithJwtTransition("preivous", "thenewone"))
  83. func() {
  84. defer func() {
  85. p := recover()
  86. switch v := p.(type) {
  87. case error:
  88. assert.Equal(t, "foo", v.Error())
  89. default:
  90. t.Fail()
  91. }
  92. }()
  93. svr.Start()
  94. svr.Stop()
  95. }()
  96. }
  97. }
  98. func TestWithMaxBytes(t *testing.T) {
  99. const maxBytes = 1000
  100. var fr featuredRoutes
  101. WithMaxBytes(maxBytes)(&fr)
  102. assert.Equal(t, int64(maxBytes), fr.maxBytes)
  103. }
  104. func TestWithMiddleware(t *testing.T) {
  105. m := make(map[string]string)
  106. rt := router.NewRouter()
  107. handler := func(w http.ResponseWriter, r *http.Request) {
  108. var v struct {
  109. Nickname string `form:"nickname"`
  110. Zipcode int64 `form:"zipcode"`
  111. }
  112. err := httpx.Parse(r, &v)
  113. assert.Nil(t, err)
  114. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  115. assert.Nil(t, err)
  116. }
  117. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  118. return func(w http.ResponseWriter, r *http.Request) {
  119. var v struct {
  120. Name string `path:"name"`
  121. Year string `path:"year"`
  122. }
  123. assert.Nil(t, httpx.ParsePath(r, &v))
  124. m[v.Name] = v.Year
  125. next.ServeHTTP(w, r)
  126. }
  127. }, Route{
  128. Method: http.MethodGet,
  129. Path: "/first/:name/:year",
  130. Handler: handler,
  131. }, Route{
  132. Method: http.MethodGet,
  133. Path: "/second/:name/:year",
  134. Handler: handler,
  135. })
  136. urls := []string{
  137. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  138. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  139. }
  140. for _, route := range rs {
  141. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  142. }
  143. for _, url := range urls {
  144. r, err := http.NewRequest(http.MethodGet, url, nil)
  145. assert.Nil(t, err)
  146. rr := httptest.NewRecorder()
  147. rt.ServeHTTP(rr, r)
  148. assert.Equal(t, "whatever:200000", rr.Body.String())
  149. }
  150. assert.EqualValues(t, map[string]string{
  151. "kevin": "2017",
  152. "wan": "2020",
  153. }, m)
  154. }
  155. func TestMultiMiddlewares(t *testing.T) {
  156. m := make(map[string]string)
  157. rt := router.NewRouter()
  158. handler := func(w http.ResponseWriter, r *http.Request) {
  159. var v struct {
  160. Nickname string `form:"nickname"`
  161. Zipcode int64 `form:"zipcode"`
  162. }
  163. err := httpx.Parse(r, &v)
  164. assert.Nil(t, err)
  165. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  166. assert.Nil(t, err)
  167. }
  168. rs := WithMiddlewares([]Middleware{
  169. func(next http.HandlerFunc) http.HandlerFunc {
  170. return func(w http.ResponseWriter, r *http.Request) {
  171. var v struct {
  172. Name string `path:"name"`
  173. Year string `path:"year"`
  174. }
  175. assert.Nil(t, httpx.ParsePath(r, &v))
  176. m[v.Name] = v.Year
  177. next.ServeHTTP(w, r)
  178. }
  179. },
  180. func(next http.HandlerFunc) http.HandlerFunc {
  181. return func(w http.ResponseWriter, r *http.Request) {
  182. var v struct {
  183. Name string `form:"nickname"`
  184. Zipcode string `form:"zipcode"`
  185. }
  186. assert.Nil(t, httpx.ParseForm(r, &v))
  187. assert.NotEmpty(t, m)
  188. m[v.Name] = v.Zipcode + v.Zipcode
  189. next.ServeHTTP(w, r)
  190. }
  191. },
  192. ToMiddleware(func(next http.Handler) http.Handler {
  193. return next
  194. }),
  195. }, Route{
  196. Method: http.MethodGet,
  197. Path: "/first/:name/:year",
  198. Handler: handler,
  199. }, Route{
  200. Method: http.MethodGet,
  201. Path: "/second/:name/:year",
  202. Handler: handler,
  203. })
  204. urls := []string{
  205. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  206. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  207. }
  208. for _, route := range rs {
  209. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  210. }
  211. for _, url := range urls {
  212. r, err := http.NewRequest(http.MethodGet, url, nil)
  213. assert.Nil(t, err)
  214. rr := httptest.NewRecorder()
  215. rt.ServeHTTP(rr, r)
  216. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  217. }
  218. assert.EqualValues(t, map[string]string{
  219. "kevin": "2017",
  220. "wan": "2020",
  221. "whatever": "200000200000",
  222. }, m)
  223. }
  224. func TestWithPrefix(t *testing.T) {
  225. fr := featuredRoutes{
  226. routes: []Route{
  227. {
  228. Path: "/hello",
  229. },
  230. {
  231. Path: "/world",
  232. },
  233. },
  234. }
  235. WithPrefix("/api")(&fr)
  236. var vals []string
  237. for _, r := range fr.routes {
  238. vals = append(vals, r.Path)
  239. }
  240. assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
  241. }
  242. func TestWithPriority(t *testing.T) {
  243. var fr featuredRoutes
  244. WithPriority()(&fr)
  245. assert.True(t, fr.priority)
  246. }
  247. func TestWithTimeout(t *testing.T) {
  248. var fr featuredRoutes
  249. WithTimeout(time.Hour)(&fr)
  250. assert.Equal(t, time.Hour, fr.timeout)
  251. }
  252. func TestWithTLSConfig(t *testing.T) {
  253. const configYaml = `
  254. Name: foo
  255. Port: 54321
  256. `
  257. var cnf RestConf
  258. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  259. testConfig := &tls.Config{
  260. CipherSuites: []uint16{
  261. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  262. },
  263. }
  264. testCases := []struct {
  265. c RestConf
  266. opts []RunOption
  267. res *tls.Config
  268. }{
  269. {
  270. c: cnf,
  271. opts: []RunOption{WithTLSConfig(testConfig)},
  272. res: testConfig,
  273. },
  274. {
  275. c: cnf,
  276. opts: []RunOption{WithUnsignedCallback(nil)},
  277. res: nil,
  278. },
  279. }
  280. for _, testCase := range testCases {
  281. svr, err := NewServer(testCase.c, testCase.opts...)
  282. assert.Nil(t, err)
  283. assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
  284. }
  285. }
  286. func TestWithCors(t *testing.T) {
  287. const configYaml = `
  288. Name: foo
  289. Port: 54321
  290. `
  291. var cnf RestConf
  292. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  293. rt := router.NewRouter()
  294. svr, err := NewServer(cnf, WithRouter(rt))
  295. assert.Nil(t, err)
  296. defer svr.Stop()
  297. opt := WithCors("local")
  298. opt(svr)
  299. }
  300. func TestWithCustomCors(t *testing.T) {
  301. const configYaml = `
  302. Name: foo
  303. Port: 54321
  304. `
  305. var cnf RestConf
  306. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  307. rt := router.NewRouter()
  308. svr, err := NewServer(cnf, WithRouter(rt))
  309. assert.Nil(t, err)
  310. opt := WithCustomCors(func(header http.Header) {
  311. header.Set("foo", "bar")
  312. }, func(w http.ResponseWriter) {
  313. w.WriteHeader(http.StatusOK)
  314. }, "local")
  315. opt(svr)
  316. }
  317. func TestServer_PrintRoutes(t *testing.T) {
  318. const (
  319. configYaml = `
  320. Name: foo
  321. Port: 54321
  322. `
  323. expect = `Routes:
  324. GET /bar
  325. GET /foo
  326. GET /foo/:bar
  327. GET /foo/:bar/baz
  328. `
  329. )
  330. var cnf RestConf
  331. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  332. svr, err := NewServer(cnf)
  333. assert.Nil(t, err)
  334. svr.AddRoutes([]Route{
  335. {
  336. Method: http.MethodGet,
  337. Path: "/foo",
  338. Handler: http.NotFound,
  339. },
  340. {
  341. Method: http.MethodGet,
  342. Path: "/bar",
  343. Handler: http.NotFound,
  344. },
  345. {
  346. Method: http.MethodGet,
  347. Path: "/foo/:bar",
  348. Handler: http.NotFound,
  349. },
  350. {
  351. Method: http.MethodGet,
  352. Path: "/foo/:bar/baz",
  353. Handler: http.NotFound,
  354. },
  355. })
  356. old := os.Stdout
  357. r, w, err := os.Pipe()
  358. assert.Nil(t, err)
  359. os.Stdout = w
  360. defer func() {
  361. os.Stdout = old
  362. }()
  363. svr.PrintRoutes()
  364. ch := make(chan string)
  365. go func() {
  366. var buf strings.Builder
  367. io.Copy(&buf, r)
  368. ch <- buf.String()
  369. }()
  370. w.Close()
  371. out := <-ch
  372. assert.Equal(t, expect, out)
  373. }
  374. func TestServer_Routes(t *testing.T) {
  375. const (
  376. configYaml = `
  377. Name: foo
  378. Port: 54321
  379. `
  380. expect = `GET /foo GET /bar GET /foo/:bar GET /foo/:bar/baz`
  381. )
  382. var cnf RestConf
  383. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  384. svr, err := NewServer(cnf)
  385. assert.Nil(t, err)
  386. svr.AddRoutes([]Route{
  387. {
  388. Method: http.MethodGet,
  389. Path: "/foo",
  390. Handler: http.NotFound,
  391. },
  392. {
  393. Method: http.MethodGet,
  394. Path: "/bar",
  395. Handler: http.NotFound,
  396. },
  397. {
  398. Method: http.MethodGet,
  399. Path: "/foo/:bar",
  400. Handler: http.NotFound,
  401. },
  402. {
  403. Method: http.MethodGet,
  404. Path: "/foo/:bar/baz",
  405. Handler: http.NotFound,
  406. },
  407. })
  408. routes := svr.Routes()
  409. var buf strings.Builder
  410. for i := 0; i < len(routes); i++ {
  411. buf.WriteString(routes[i].Method)
  412. buf.WriteString(" ")
  413. buf.WriteString(routes[i].Path)
  414. buf.WriteString(" ")
  415. }
  416. assert.Equal(t, expect, strings.Trim(buf.String(), " "))
  417. }
  418. func TestHandleError(t *testing.T) {
  419. assert.NotPanics(t, func() {
  420. handleError(nil)
  421. handleError(http.ErrServerClosed)
  422. })
  423. }
  424. func TestValidateSecret(t *testing.T) {
  425. assert.Panics(t, func() {
  426. validateSecret("short")
  427. })
  428. }
  429. func TestServer_WithChain(t *testing.T) {
  430. var called int32
  431. middleware1 := func() func(http.Handler) http.Handler {
  432. return func(next http.Handler) http.Handler {
  433. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  434. atomic.AddInt32(&called, 1)
  435. next.ServeHTTP(w, r)
  436. atomic.AddInt32(&called, 1)
  437. })
  438. }
  439. }
  440. middleware2 := func() func(http.Handler) http.Handler {
  441. return func(next http.Handler) http.Handler {
  442. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  443. atomic.AddInt32(&called, 1)
  444. next.ServeHTTP(w, r)
  445. atomic.AddInt32(&called, 1)
  446. })
  447. }
  448. }
  449. server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
  450. server.AddRoutes(
  451. []Route{
  452. {
  453. Method: http.MethodGet,
  454. Path: "/",
  455. Handler: func(_ http.ResponseWriter, _ *http.Request) {
  456. atomic.AddInt32(&called, 1)
  457. },
  458. },
  459. },
  460. )
  461. rt := router.NewRouter()
  462. assert.Nil(t, server.ngin.bindRoutes(rt))
  463. req, err := http.NewRequest(http.MethodGet, "/", nil)
  464. assert.Nil(t, err)
  465. rt.ServeHTTP(httptest.NewRecorder(), req)
  466. assert.Equal(t, int32(5), atomic.LoadInt32(&called))
  467. }
  468. func TestServer_WithCors(t *testing.T) {
  469. var called int32
  470. middleware := func(next http.Handler) http.Handler {
  471. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  472. atomic.AddInt32(&called, 1)
  473. next.ServeHTTP(w, r)
  474. })
  475. }
  476. r := router.NewRouter()
  477. assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler())))
  478. cr := &corsRouter{
  479. Router: r,
  480. middleware: cors.Middleware(nil, "*"),
  481. }
  482. req := httptest.NewRequest(http.MethodOptions, "/", nil)
  483. cr.ServeHTTP(httptest.NewRecorder(), req)
  484. assert.Equal(t, int32(0), atomic.LoadInt32(&called))
  485. }