server_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "strings"
  11. "sync/atomic"
  12. "testing"
  13. "time"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/zeromicro/go-zero/core/conf"
  16. "github.com/zeromicro/go-zero/core/logx"
  17. "github.com/zeromicro/go-zero/core/service"
  18. "github.com/zeromicro/go-zero/rest/chain"
  19. "github.com/zeromicro/go-zero/rest/httpx"
  20. "github.com/zeromicro/go-zero/rest/router"
  21. )
  22. func TestNewServer(t *testing.T) {
  23. writer := logx.Reset()
  24. defer logx.SetWriter(writer)
  25. logx.SetWriter(logx.NewWriter(ioutil.Discard))
  26. const configYaml = `
  27. Name: foo
  28. Port: 54321
  29. `
  30. var cnf RestConf
  31. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  32. tests := []struct {
  33. c RestConf
  34. opts []RunOption
  35. fail bool
  36. }{
  37. {
  38. c: RestConf{},
  39. opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
  40. },
  41. {
  42. c: cnf,
  43. opts: []RunOption{WithRouter(mockedRouter{})},
  44. },
  45. {
  46. c: cnf,
  47. opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
  48. },
  49. {
  50. c: cnf,
  51. opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
  52. },
  53. {
  54. c: cnf,
  55. opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
  56. },
  57. {
  58. c: cnf,
  59. opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
  60. },
  61. }
  62. for _, test := range tests {
  63. var svr *Server
  64. var err error
  65. if test.fail {
  66. _, err = NewServer(test.c, test.opts...)
  67. assert.NotNil(t, err)
  68. continue
  69. } else {
  70. svr = MustNewServer(test.c, test.opts...)
  71. }
  72. svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
  73. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  74. next.ServeHTTP(w, r)
  75. })
  76. }))
  77. svr.AddRoute(Route{
  78. Method: http.MethodGet,
  79. Path: "/",
  80. Handler: nil,
  81. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  82. WithJwtTransition("preivous", "thenewone"))
  83. func() {
  84. defer func() {
  85. p := recover()
  86. switch v := p.(type) {
  87. case error:
  88. assert.Equal(t, "foo", v.Error())
  89. default:
  90. t.Fail()
  91. }
  92. }()
  93. svr.Start()
  94. svr.Stop()
  95. }()
  96. }
  97. }
  98. func TestNewServerError(t *testing.T) {
  99. _, err := NewServer(RestConf{
  100. ServiceConf: service.ServiceConf{
  101. Log: logx.LogConf{
  102. // file mode, no path specified
  103. Mode: "file",
  104. },
  105. },
  106. })
  107. assert.NotNil(t, err)
  108. }
  109. func TestWithMaxBytes(t *testing.T) {
  110. const maxBytes = 1000
  111. var fr featuredRoutes
  112. WithMaxBytes(maxBytes)(&fr)
  113. assert.Equal(t, int64(maxBytes), fr.maxBytes)
  114. }
  115. func TestWithMiddleware(t *testing.T) {
  116. m := make(map[string]string)
  117. rt := router.NewRouter()
  118. handler := func(w http.ResponseWriter, r *http.Request) {
  119. var v struct {
  120. Nickname string `form:"nickname"`
  121. Zipcode int64 `form:"zipcode"`
  122. }
  123. err := httpx.Parse(r, &v)
  124. assert.Nil(t, err)
  125. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  126. assert.Nil(t, err)
  127. }
  128. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  129. return func(w http.ResponseWriter, r *http.Request) {
  130. var v struct {
  131. Name string `path:"name"`
  132. Year string `path:"year"`
  133. }
  134. assert.Nil(t, httpx.ParsePath(r, &v))
  135. m[v.Name] = v.Year
  136. next.ServeHTTP(w, r)
  137. }
  138. }, Route{
  139. Method: http.MethodGet,
  140. Path: "/first/:name/:year",
  141. Handler: handler,
  142. }, Route{
  143. Method: http.MethodGet,
  144. Path: "/second/:name/:year",
  145. Handler: handler,
  146. })
  147. urls := []string{
  148. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  149. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  150. }
  151. for _, route := range rs {
  152. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  153. }
  154. for _, url := range urls {
  155. r, err := http.NewRequest(http.MethodGet, url, nil)
  156. assert.Nil(t, err)
  157. rr := httptest.NewRecorder()
  158. rt.ServeHTTP(rr, r)
  159. assert.Equal(t, "whatever:200000", rr.Body.String())
  160. }
  161. assert.EqualValues(t, map[string]string{
  162. "kevin": "2017",
  163. "wan": "2020",
  164. }, m)
  165. }
  166. func TestMultiMiddlewares(t *testing.T) {
  167. m := make(map[string]string)
  168. rt := router.NewRouter()
  169. handler := func(w http.ResponseWriter, r *http.Request) {
  170. var v struct {
  171. Nickname string `form:"nickname"`
  172. Zipcode int64 `form:"zipcode"`
  173. }
  174. err := httpx.Parse(r, &v)
  175. assert.Nil(t, err)
  176. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  177. assert.Nil(t, err)
  178. }
  179. rs := WithMiddlewares([]Middleware{
  180. func(next http.HandlerFunc) http.HandlerFunc {
  181. return func(w http.ResponseWriter, r *http.Request) {
  182. var v struct {
  183. Name string `path:"name"`
  184. Year string `path:"year"`
  185. }
  186. assert.Nil(t, httpx.ParsePath(r, &v))
  187. m[v.Name] = v.Year
  188. next.ServeHTTP(w, r)
  189. }
  190. },
  191. func(next http.HandlerFunc) http.HandlerFunc {
  192. return func(w http.ResponseWriter, r *http.Request) {
  193. var v struct {
  194. Name string `form:"nickname"`
  195. Zipcode string `form:"zipcode"`
  196. }
  197. assert.Nil(t, httpx.ParseForm(r, &v))
  198. assert.NotEmpty(t, m)
  199. m[v.Name] = v.Zipcode + v.Zipcode
  200. next.ServeHTTP(w, r)
  201. }
  202. },
  203. ToMiddleware(func(next http.Handler) http.Handler {
  204. return next
  205. }),
  206. }, Route{
  207. Method: http.MethodGet,
  208. Path: "/first/:name/:year",
  209. Handler: handler,
  210. }, Route{
  211. Method: http.MethodGet,
  212. Path: "/second/:name/:year",
  213. Handler: handler,
  214. })
  215. urls := []string{
  216. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  217. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  218. }
  219. for _, route := range rs {
  220. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  221. }
  222. for _, url := range urls {
  223. r, err := http.NewRequest(http.MethodGet, url, nil)
  224. assert.Nil(t, err)
  225. rr := httptest.NewRecorder()
  226. rt.ServeHTTP(rr, r)
  227. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  228. }
  229. assert.EqualValues(t, map[string]string{
  230. "kevin": "2017",
  231. "wan": "2020",
  232. "whatever": "200000200000",
  233. }, m)
  234. }
  235. func TestWithPrefix(t *testing.T) {
  236. fr := featuredRoutes{
  237. routes: []Route{
  238. {
  239. Path: "/hello",
  240. },
  241. {
  242. Path: "/world",
  243. },
  244. },
  245. }
  246. WithPrefix("/api")(&fr)
  247. var vals []string
  248. for _, r := range fr.routes {
  249. vals = append(vals, r.Path)
  250. }
  251. assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
  252. }
  253. func TestWithPriority(t *testing.T) {
  254. var fr featuredRoutes
  255. WithPriority()(&fr)
  256. assert.True(t, fr.priority)
  257. }
  258. func TestWithTimeout(t *testing.T) {
  259. var fr featuredRoutes
  260. WithTimeout(time.Hour)(&fr)
  261. assert.Equal(t, time.Hour, fr.timeout)
  262. }
  263. func TestWithTLSConfig(t *testing.T) {
  264. const configYaml = `
  265. Name: foo
  266. Port: 54321
  267. `
  268. var cnf RestConf
  269. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  270. testConfig := &tls.Config{
  271. CipherSuites: []uint16{
  272. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  273. },
  274. }
  275. testCases := []struct {
  276. c RestConf
  277. opts []RunOption
  278. res *tls.Config
  279. }{
  280. {
  281. c: cnf,
  282. opts: []RunOption{WithTLSConfig(testConfig)},
  283. res: testConfig,
  284. },
  285. {
  286. c: cnf,
  287. opts: []RunOption{WithUnsignedCallback(nil)},
  288. res: nil,
  289. },
  290. }
  291. for _, testCase := range testCases {
  292. svr, err := NewServer(testCase.c, testCase.opts...)
  293. assert.Nil(t, err)
  294. assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
  295. }
  296. }
  297. func TestWithCors(t *testing.T) {
  298. const configYaml = `
  299. Name: foo
  300. Port: 54321
  301. `
  302. var cnf RestConf
  303. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  304. rt := router.NewRouter()
  305. svr, err := NewServer(cnf, WithRouter(rt))
  306. assert.Nil(t, err)
  307. defer svr.Stop()
  308. opt := WithCors("local")
  309. opt(svr)
  310. }
  311. func TestWithCustomCors(t *testing.T) {
  312. const configYaml = `
  313. Name: foo
  314. Port: 54321
  315. `
  316. var cnf RestConf
  317. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  318. rt := router.NewRouter()
  319. svr, err := NewServer(cnf, WithRouter(rt))
  320. assert.Nil(t, err)
  321. opt := WithCustomCors(func(header http.Header) {
  322. header.Set("foo", "bar")
  323. }, func(w http.ResponseWriter) {
  324. w.WriteHeader(http.StatusOK)
  325. }, "local")
  326. opt(svr)
  327. }
  328. func TestServer_PrintRoutes(t *testing.T) {
  329. const (
  330. configYaml = `
  331. Name: foo
  332. Port: 54321
  333. `
  334. expect = `Routes:
  335. GET /bar
  336. GET /foo
  337. GET /foo/:bar
  338. GET /foo/:bar/baz
  339. `
  340. )
  341. var cnf RestConf
  342. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  343. svr, err := NewServer(cnf)
  344. assert.Nil(t, err)
  345. svr.AddRoutes([]Route{
  346. {
  347. Method: http.MethodGet,
  348. Path: "/foo",
  349. Handler: http.NotFound,
  350. },
  351. {
  352. Method: http.MethodGet,
  353. Path: "/bar",
  354. Handler: http.NotFound,
  355. },
  356. {
  357. Method: http.MethodGet,
  358. Path: "/foo/:bar",
  359. Handler: http.NotFound,
  360. },
  361. {
  362. Method: http.MethodGet,
  363. Path: "/foo/:bar/baz",
  364. Handler: http.NotFound,
  365. },
  366. })
  367. old := os.Stdout
  368. r, w, err := os.Pipe()
  369. assert.Nil(t, err)
  370. os.Stdout = w
  371. defer func() {
  372. os.Stdout = old
  373. }()
  374. svr.PrintRoutes()
  375. ch := make(chan string)
  376. go func() {
  377. var buf strings.Builder
  378. io.Copy(&buf, r)
  379. ch <- buf.String()
  380. }()
  381. w.Close()
  382. out := <-ch
  383. assert.Equal(t, expect, out)
  384. }
  385. func TestServer_Routes(t *testing.T) {
  386. const (
  387. configYaml = `
  388. Name: foo
  389. Port: 54321
  390. `
  391. expect = `GET /foo GET /bar GET /foo/:bar GET /foo/:bar/baz`
  392. )
  393. var cnf RestConf
  394. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  395. svr, err := NewServer(cnf)
  396. assert.Nil(t, err)
  397. svr.AddRoutes([]Route{
  398. {
  399. Method: http.MethodGet,
  400. Path: "/foo",
  401. Handler: http.NotFound,
  402. },
  403. {
  404. Method: http.MethodGet,
  405. Path: "/bar",
  406. Handler: http.NotFound,
  407. },
  408. {
  409. Method: http.MethodGet,
  410. Path: "/foo/:bar",
  411. Handler: http.NotFound,
  412. },
  413. {
  414. Method: http.MethodGet,
  415. Path: "/foo/:bar/baz",
  416. Handler: http.NotFound,
  417. },
  418. })
  419. routes := svr.Routes()
  420. var buf strings.Builder
  421. for i := 0; i < len(routes); i++ {
  422. buf.WriteString(routes[i].Method)
  423. buf.WriteString(" ")
  424. buf.WriteString(routes[i].Path)
  425. buf.WriteString(" ")
  426. }
  427. assert.Equal(t, expect, strings.Trim(buf.String(), " "))
  428. }
  429. func TestHandleError(t *testing.T) {
  430. assert.NotPanics(t, func() {
  431. handleError(nil)
  432. handleError(http.ErrServerClosed)
  433. })
  434. }
  435. func TestValidateSecret(t *testing.T) {
  436. assert.Panics(t, func() {
  437. validateSecret("short")
  438. })
  439. }
  440. func TestServer_WithChain(t *testing.T) {
  441. var called int32
  442. middleware1 := func() func(http.Handler) http.Handler {
  443. return func(next http.Handler) http.Handler {
  444. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  445. atomic.AddInt32(&called, 1)
  446. next.ServeHTTP(w, r)
  447. atomic.AddInt32(&called, 1)
  448. })
  449. }
  450. }
  451. middleware2 := func() func(http.Handler) http.Handler {
  452. return func(next http.Handler) http.Handler {
  453. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  454. atomic.AddInt32(&called, 1)
  455. next.ServeHTTP(w, r)
  456. atomic.AddInt32(&called, 1)
  457. })
  458. }
  459. }
  460. server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
  461. server.AddRoutes(
  462. []Route{
  463. {
  464. Method: http.MethodGet,
  465. Path: "/",
  466. Handler: func(_ http.ResponseWriter, _ *http.Request) {
  467. atomic.AddInt32(&called, 1)
  468. },
  469. },
  470. },
  471. )
  472. rt := router.NewRouter()
  473. assert.Nil(t, server.ngin.bindRoutes(rt))
  474. req, err := http.NewRequest(http.MethodGet, "/", nil)
  475. assert.Nil(t, err)
  476. rt.ServeHTTP(httptest.NewRecorder(), req)
  477. assert.Equal(t, int32(5), atomic.LoadInt32(&called))
  478. }