server_test.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "strings"
  11. "testing"
  12. "time"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/zeromicro/go-zero/core/conf"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. "github.com/zeromicro/go-zero/rest/httpx"
  17. "github.com/zeromicro/go-zero/rest/router"
  18. )
  19. func TestNewServer(t *testing.T) {
  20. writer := logx.Reset()
  21. defer logx.SetWriter(writer)
  22. logx.SetWriter(logx.NewWriter(ioutil.Discard))
  23. const configYaml = `
  24. Name: foo
  25. Port: 54321
  26. `
  27. var cnf RestConf
  28. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  29. tests := []struct {
  30. c RestConf
  31. opts []RunOption
  32. fail bool
  33. }{
  34. {
  35. c: RestConf{},
  36. opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
  37. },
  38. {
  39. c: cnf,
  40. opts: []RunOption{WithRouter(mockedRouter{})},
  41. },
  42. {
  43. c: cnf,
  44. opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
  45. },
  46. {
  47. c: cnf,
  48. opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
  49. },
  50. {
  51. c: cnf,
  52. opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
  53. },
  54. {
  55. c: cnf,
  56. opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
  57. },
  58. }
  59. for _, test := range tests {
  60. var svr *Server
  61. var err error
  62. if test.fail {
  63. _, err = NewServer(test.c, test.opts...)
  64. assert.NotNil(t, err)
  65. continue
  66. } else {
  67. svr = MustNewServer(test.c, test.opts...)
  68. }
  69. svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
  70. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  71. next.ServeHTTP(w, r)
  72. })
  73. }))
  74. svr.AddRoute(Route{
  75. Method: http.MethodGet,
  76. Path: "/",
  77. Handler: nil,
  78. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  79. WithJwtTransition("preivous", "thenewone"))
  80. func() {
  81. defer func() {
  82. p := recover()
  83. switch v := p.(type) {
  84. case error:
  85. assert.Equal(t, "foo", v.Error())
  86. default:
  87. t.Fail()
  88. }
  89. }()
  90. svr.Start()
  91. svr.Stop()
  92. }()
  93. }
  94. }
  95. func TestWithMaxBytes(t *testing.T) {
  96. const maxBytes = 1000
  97. var fr featuredRoutes
  98. WithMaxBytes(maxBytes)(&fr)
  99. assert.Equal(t, int64(maxBytes), fr.maxBytes)
  100. }
  101. func TestWithMiddleware(t *testing.T) {
  102. m := make(map[string]string)
  103. rt := router.NewRouter()
  104. handler := func(w http.ResponseWriter, r *http.Request) {
  105. var v struct {
  106. Nickname string `form:"nickname"`
  107. Zipcode int64 `form:"zipcode"`
  108. }
  109. err := httpx.Parse(r, &v)
  110. assert.Nil(t, err)
  111. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  112. assert.Nil(t, err)
  113. }
  114. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  115. return func(w http.ResponseWriter, r *http.Request) {
  116. var v struct {
  117. Name string `path:"name"`
  118. Year string `path:"year"`
  119. }
  120. assert.Nil(t, httpx.ParsePath(r, &v))
  121. m[v.Name] = v.Year
  122. next.ServeHTTP(w, r)
  123. }
  124. }, Route{
  125. Method: http.MethodGet,
  126. Path: "/first/:name/:year",
  127. Handler: handler,
  128. }, Route{
  129. Method: http.MethodGet,
  130. Path: "/second/:name/:year",
  131. Handler: handler,
  132. })
  133. urls := []string{
  134. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  135. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  136. }
  137. for _, route := range rs {
  138. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  139. }
  140. for _, url := range urls {
  141. r, err := http.NewRequest(http.MethodGet, url, nil)
  142. assert.Nil(t, err)
  143. rr := httptest.NewRecorder()
  144. rt.ServeHTTP(rr, r)
  145. assert.Equal(t, "whatever:200000", rr.Body.String())
  146. }
  147. assert.EqualValues(t, map[string]string{
  148. "kevin": "2017",
  149. "wan": "2020",
  150. }, m)
  151. }
  152. func TestMultiMiddlewares(t *testing.T) {
  153. m := make(map[string]string)
  154. rt := router.NewRouter()
  155. handler := func(w http.ResponseWriter, r *http.Request) {
  156. var v struct {
  157. Nickname string `form:"nickname"`
  158. Zipcode int64 `form:"zipcode"`
  159. }
  160. err := httpx.Parse(r, &v)
  161. assert.Nil(t, err)
  162. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  163. assert.Nil(t, err)
  164. }
  165. rs := WithMiddlewares([]Middleware{
  166. func(next http.HandlerFunc) http.HandlerFunc {
  167. return func(w http.ResponseWriter, r *http.Request) {
  168. var v struct {
  169. Name string `path:"name"`
  170. Year string `path:"year"`
  171. }
  172. assert.Nil(t, httpx.ParsePath(r, &v))
  173. m[v.Name] = v.Year
  174. next.ServeHTTP(w, r)
  175. }
  176. },
  177. func(next http.HandlerFunc) http.HandlerFunc {
  178. return func(w http.ResponseWriter, r *http.Request) {
  179. var v struct {
  180. Name string `form:"nickname"`
  181. Zipcode string `form:"zipcode"`
  182. }
  183. assert.Nil(t, httpx.ParseForm(r, &v))
  184. assert.NotEmpty(t, m)
  185. m[v.Name] = v.Zipcode + v.Zipcode
  186. next.ServeHTTP(w, r)
  187. }
  188. },
  189. ToMiddleware(func(next http.Handler) http.Handler {
  190. return next
  191. }),
  192. }, Route{
  193. Method: http.MethodGet,
  194. Path: "/first/:name/:year",
  195. Handler: handler,
  196. }, Route{
  197. Method: http.MethodGet,
  198. Path: "/second/:name/:year",
  199. Handler: handler,
  200. })
  201. urls := []string{
  202. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  203. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  204. }
  205. for _, route := range rs {
  206. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  207. }
  208. for _, url := range urls {
  209. r, err := http.NewRequest(http.MethodGet, url, nil)
  210. assert.Nil(t, err)
  211. rr := httptest.NewRecorder()
  212. rt.ServeHTTP(rr, r)
  213. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  214. }
  215. assert.EqualValues(t, map[string]string{
  216. "kevin": "2017",
  217. "wan": "2020",
  218. "whatever": "200000200000",
  219. }, m)
  220. }
  221. func TestWithPrefix(t *testing.T) {
  222. fr := featuredRoutes{
  223. routes: []Route{
  224. {
  225. Path: "/hello",
  226. },
  227. {
  228. Path: "/world",
  229. },
  230. },
  231. }
  232. WithPrefix("/api")(&fr)
  233. var vals []string
  234. for _, r := range fr.routes {
  235. vals = append(vals, r.Path)
  236. }
  237. assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
  238. }
  239. func TestWithPriority(t *testing.T) {
  240. var fr featuredRoutes
  241. WithPriority()(&fr)
  242. assert.True(t, fr.priority)
  243. }
  244. func TestWithTimeout(t *testing.T) {
  245. var fr featuredRoutes
  246. WithTimeout(time.Hour)(&fr)
  247. assert.Equal(t, time.Hour, fr.timeout)
  248. }
  249. func TestWithTLSConfig(t *testing.T) {
  250. const configYaml = `
  251. Name: foo
  252. Port: 54321
  253. `
  254. var cnf RestConf
  255. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  256. testConfig := &tls.Config{
  257. CipherSuites: []uint16{
  258. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  259. },
  260. }
  261. testCases := []struct {
  262. c RestConf
  263. opts []RunOption
  264. res *tls.Config
  265. }{
  266. {
  267. c: cnf,
  268. opts: []RunOption{WithTLSConfig(testConfig)},
  269. res: testConfig,
  270. },
  271. {
  272. c: cnf,
  273. opts: []RunOption{WithUnsignedCallback(nil)},
  274. res: nil,
  275. },
  276. }
  277. for _, testCase := range testCases {
  278. svr, err := NewServer(testCase.c, testCase.opts...)
  279. assert.Nil(t, err)
  280. assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
  281. }
  282. }
  283. func TestWithCors(t *testing.T) {
  284. const configYaml = `
  285. Name: foo
  286. Port: 54321
  287. `
  288. var cnf RestConf
  289. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  290. rt := router.NewRouter()
  291. svr, err := NewServer(cnf, WithRouter(rt))
  292. assert.Nil(t, err)
  293. opt := WithCors("local")
  294. opt(svr)
  295. }
  296. func TestWithCustomCors(t *testing.T) {
  297. const configYaml = `
  298. Name: foo
  299. Port: 54321
  300. `
  301. var cnf RestConf
  302. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  303. rt := router.NewRouter()
  304. svr, err := NewServer(cnf, WithRouter(rt))
  305. assert.Nil(t, err)
  306. opt := WithCustomCors(func(header http.Header) {
  307. header.Set("foo", "bar")
  308. }, func(w http.ResponseWriter) {
  309. w.WriteHeader(http.StatusOK)
  310. }, "local")
  311. opt(svr)
  312. }
  313. func TestServer_PrintRoutes(t *testing.T) {
  314. const (
  315. configYaml = `
  316. Name: foo
  317. Port: 54321
  318. `
  319. expect = `GET /bar
  320. GET /foo
  321. GET /foo/:bar
  322. GET /foo/:bar/baz
  323. `
  324. )
  325. var cnf RestConf
  326. assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
  327. svr, err := NewServer(cnf)
  328. assert.Nil(t, err)
  329. svr.AddRoutes([]Route{
  330. {
  331. Method: http.MethodGet,
  332. Path: "/foo",
  333. Handler: http.NotFound,
  334. },
  335. {
  336. Method: http.MethodGet,
  337. Path: "/bar",
  338. Handler: http.NotFound,
  339. },
  340. {
  341. Method: http.MethodGet,
  342. Path: "/foo/:bar",
  343. Handler: http.NotFound,
  344. },
  345. {
  346. Method: http.MethodGet,
  347. Path: "/foo/:bar/baz",
  348. Handler: http.NotFound,
  349. },
  350. })
  351. old := os.Stdout
  352. r, w, err := os.Pipe()
  353. assert.Nil(t, err)
  354. os.Stdout = w
  355. defer func() {
  356. os.Stdout = old
  357. }()
  358. svr.PrintRoutes()
  359. ch := make(chan string)
  360. go func() {
  361. var buf strings.Builder
  362. io.Copy(&buf, r)
  363. ch <- buf.String()
  364. }()
  365. w.Close()
  366. out := <-ch
  367. assert.Equal(t, expect, out)
  368. }