server_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/tal-tech/go-zero/core/conf"
  11. "github.com/tal-tech/go-zero/rest/httpx"
  12. "github.com/tal-tech/go-zero/rest/router"
  13. )
  14. func TestNewServer(t *testing.T) {
  15. const configYaml = `
  16. Name: foo
  17. Port: 54321
  18. `
  19. var cnf RestConf
  20. assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
  21. failStart := func(server *Server) {
  22. server.opts.start = func(e *engine) error {
  23. return http.ErrServerClosed
  24. }
  25. }
  26. tests := []struct {
  27. c RestConf
  28. opts []RunOption
  29. fail bool
  30. }{
  31. {
  32. c: RestConf{},
  33. opts: []RunOption{failStart},
  34. fail: true,
  35. },
  36. {
  37. c: cnf,
  38. opts: []RunOption{failStart},
  39. },
  40. {
  41. c: cnf,
  42. opts: []RunOption{WithNotAllowedHandler(nil), failStart},
  43. },
  44. {
  45. c: cnf,
  46. opts: []RunOption{WithNotFoundHandler(nil), failStart},
  47. },
  48. {
  49. c: cnf,
  50. opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
  51. },
  52. {
  53. c: cnf,
  54. opts: []RunOption{WithUnsignedCallback(nil), failStart},
  55. },
  56. }
  57. for _, test := range tests {
  58. srv, err := NewServer(test.c, test.opts...)
  59. if test.fail {
  60. assert.NotNil(t, err)
  61. }
  62. if err != nil {
  63. continue
  64. }
  65. srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
  66. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  67. next.ServeHTTP(w, r)
  68. })
  69. }))
  70. srv.AddRoute(Route{
  71. Method: http.MethodGet,
  72. Path: "/",
  73. Handler: nil,
  74. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  75. WithJwtTransition("preivous", "thenewone"))
  76. srv.Start()
  77. srv.Stop()
  78. }
  79. }
  80. func TestWithMiddleware(t *testing.T) {
  81. m := make(map[string]string)
  82. rt := router.NewRouter()
  83. handler := func(w http.ResponseWriter, r *http.Request) {
  84. var v struct {
  85. Nickname string `form:"nickname"`
  86. Zipcode int64 `form:"zipcode"`
  87. }
  88. err := httpx.Parse(r, &v)
  89. assert.Nil(t, err)
  90. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  91. assert.Nil(t, err)
  92. }
  93. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  94. return func(w http.ResponseWriter, r *http.Request) {
  95. var v struct {
  96. Name string `path:"name"`
  97. Year string `path:"year"`
  98. }
  99. assert.Nil(t, httpx.ParsePath(r, &v))
  100. m[v.Name] = v.Year
  101. next.ServeHTTP(w, r)
  102. }
  103. }, Route{
  104. Method: http.MethodGet,
  105. Path: "/first/:name/:year",
  106. Handler: handler,
  107. }, Route{
  108. Method: http.MethodGet,
  109. Path: "/second/:name/:year",
  110. Handler: handler,
  111. })
  112. urls := []string{
  113. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  114. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  115. }
  116. for _, route := range rs {
  117. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  118. }
  119. for _, url := range urls {
  120. r, err := http.NewRequest(http.MethodGet, url, nil)
  121. assert.Nil(t, err)
  122. rr := httptest.NewRecorder()
  123. rt.ServeHTTP(rr, r)
  124. assert.Equal(t, "whatever:200000", rr.Body.String())
  125. }
  126. assert.EqualValues(t, map[string]string{
  127. "kevin": "2017",
  128. "wan": "2020",
  129. }, m)
  130. }
  131. func TestMultiMiddlewares(t *testing.T) {
  132. m := make(map[string]string)
  133. rt := router.NewRouter()
  134. handler := func(w http.ResponseWriter, r *http.Request) {
  135. var v struct {
  136. Nickname string `form:"nickname"`
  137. Zipcode int64 `form:"zipcode"`
  138. }
  139. err := httpx.Parse(r, &v)
  140. assert.Nil(t, err)
  141. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  142. assert.Nil(t, err)
  143. }
  144. rs := WithMiddlewares([]Middleware{
  145. func(next http.HandlerFunc) http.HandlerFunc {
  146. return func(w http.ResponseWriter, r *http.Request) {
  147. var v struct {
  148. Name string `path:"name"`
  149. Year string `path:"year"`
  150. }
  151. assert.Nil(t, httpx.ParsePath(r, &v))
  152. m[v.Name] = v.Year
  153. next.ServeHTTP(w, r)
  154. }
  155. },
  156. func(next http.HandlerFunc) http.HandlerFunc {
  157. return func(w http.ResponseWriter, r *http.Request) {
  158. var v struct {
  159. Name string `form:"nickname"`
  160. Zipcode string `form:"zipcode"`
  161. }
  162. assert.Nil(t, httpx.ParseForm(r, &v))
  163. assert.NotEmpty(t, m)
  164. m[v.Name] = v.Zipcode + v.Zipcode
  165. next.ServeHTTP(w, r)
  166. }
  167. },
  168. }, Route{
  169. Method: http.MethodGet,
  170. Path: "/first/:name/:year",
  171. Handler: handler,
  172. }, Route{
  173. Method: http.MethodGet,
  174. Path: "/second/:name/:year",
  175. Handler: handler,
  176. })
  177. urls := []string{
  178. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  179. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  180. }
  181. for _, route := range rs {
  182. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  183. }
  184. for _, url := range urls {
  185. r, err := http.NewRequest(http.MethodGet, url, nil)
  186. assert.Nil(t, err)
  187. rr := httptest.NewRecorder()
  188. rt.ServeHTTP(rr, r)
  189. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  190. }
  191. assert.EqualValues(t, map[string]string{
  192. "kevin": "2017",
  193. "wan": "2020",
  194. "whatever": "200000200000",
  195. }, m)
  196. }
  197. func TestWithPrefix(t *testing.T) {
  198. fr := featuredRoutes{
  199. routes: []Route{
  200. {
  201. Path: "/hello",
  202. },
  203. {
  204. Path: "/world",
  205. },
  206. },
  207. }
  208. WithPrefix("/api")(&fr)
  209. var vals []string
  210. for _, r := range fr.routes {
  211. vals = append(vals, r.Path)
  212. }
  213. assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
  214. }
  215. func TestWithPriority(t *testing.T) {
  216. var fr featuredRoutes
  217. WithPriority()(&fr)
  218. assert.True(t, fr.priority)
  219. }
  220. func TestWithTLSConfig(t *testing.T) {
  221. const configYaml = `
  222. Name: foo
  223. Port: 54321
  224. `
  225. var cnf RestConf
  226. assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
  227. testConfig := &tls.Config{
  228. CipherSuites: []uint16{
  229. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  230. },
  231. }
  232. testCases := []struct {
  233. c RestConf
  234. opts []RunOption
  235. res *tls.Config
  236. }{
  237. {
  238. c: cnf,
  239. opts: []RunOption{WithTLSConfig(testConfig)},
  240. res: testConfig,
  241. },
  242. {
  243. c: cnf,
  244. opts: []RunOption{WithUnsignedCallback(nil)},
  245. res: nil,
  246. },
  247. }
  248. for _, testCase := range testCases {
  249. srv, err := NewServer(testCase.c, testCase.opts...)
  250. assert.Nil(t, err)
  251. assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
  252. }
  253. }