server_test.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package rest
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/tal-tech/go-zero/core/conf"
  11. "github.com/tal-tech/go-zero/rest/httpx"
  12. "github.com/tal-tech/go-zero/rest/router"
  13. )
  14. func TestNewServer(t *testing.T) {
  15. const configYaml = `
  16. Name: foo
  17. Port: 54321
  18. `
  19. var cnf RestConf
  20. assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
  21. failStart := func(server *Server) {
  22. server.opts.start = func(e *engine) error {
  23. return http.ErrServerClosed
  24. }
  25. }
  26. tests := []struct {
  27. c RestConf
  28. opts []RunOption
  29. fail bool
  30. }{
  31. {
  32. c: RestConf{},
  33. opts: []RunOption{failStart},
  34. fail: true,
  35. },
  36. {
  37. c: cnf,
  38. opts: []RunOption{failStart},
  39. },
  40. {
  41. c: cnf,
  42. opts: []RunOption{WithNotAllowedHandler(nil), failStart},
  43. },
  44. {
  45. c: cnf,
  46. opts: []RunOption{WithNotFoundHandler(nil), failStart},
  47. },
  48. {
  49. c: cnf,
  50. opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
  51. },
  52. {
  53. c: cnf,
  54. opts: []RunOption{WithUnsignedCallback(nil), failStart},
  55. },
  56. }
  57. for _, test := range tests {
  58. srv, err := NewServer(test.c, test.opts...)
  59. if test.fail {
  60. assert.NotNil(t, err)
  61. }
  62. if err != nil {
  63. continue
  64. }
  65. srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
  66. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  67. next.ServeHTTP(w, r)
  68. })
  69. }))
  70. srv.AddRoute(Route{
  71. Method: http.MethodGet,
  72. Path: "/",
  73. Handler: nil,
  74. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  75. WithJwtTransition("preivous", "thenewone"))
  76. srv.Start()
  77. srv.Stop()
  78. }
  79. }
  80. func TestWithMiddleware(t *testing.T) {
  81. m := make(map[string]string)
  82. rt := router.NewRouter()
  83. handler := func(w http.ResponseWriter, r *http.Request) {
  84. var v struct {
  85. Nickname string `form:"nickname"`
  86. Zipcode int64 `form:"zipcode"`
  87. }
  88. err := httpx.Parse(r, &v)
  89. assert.Nil(t, err)
  90. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  91. assert.Nil(t, err)
  92. }
  93. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  94. return func(w http.ResponseWriter, r *http.Request) {
  95. var v struct {
  96. Name string `path:"name"`
  97. Year string `path:"year"`
  98. }
  99. assert.Nil(t, httpx.ParsePath(r, &v))
  100. m[v.Name] = v.Year
  101. next.ServeHTTP(w, r)
  102. }
  103. }, Route{
  104. Method: http.MethodGet,
  105. Path: "/first/:name/:year",
  106. Handler: handler,
  107. }, Route{
  108. Method: http.MethodGet,
  109. Path: "/second/:name/:year",
  110. Handler: handler,
  111. })
  112. urls := []string{
  113. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  114. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  115. }
  116. for _, route := range rs {
  117. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  118. }
  119. for _, url := range urls {
  120. r, err := http.NewRequest(http.MethodGet, url, nil)
  121. assert.Nil(t, err)
  122. rr := httptest.NewRecorder()
  123. rt.ServeHTTP(rr, r)
  124. assert.Equal(t, "whatever:200000", rr.Body.String())
  125. }
  126. assert.EqualValues(t, map[string]string{
  127. "kevin": "2017",
  128. "wan": "2020",
  129. }, m)
  130. }
  131. func TestMultiMiddlewares(t *testing.T) {
  132. m := make(map[string]string)
  133. rt := router.NewRouter()
  134. handler := func(w http.ResponseWriter, r *http.Request) {
  135. var v struct {
  136. Nickname string `form:"nickname"`
  137. Zipcode int64 `form:"zipcode"`
  138. }
  139. err := httpx.Parse(r, &v)
  140. assert.Nil(t, err)
  141. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  142. assert.Nil(t, err)
  143. }
  144. rs := WithMiddlewares([]Middleware{
  145. func(next http.HandlerFunc) http.HandlerFunc {
  146. return func(w http.ResponseWriter, r *http.Request) {
  147. var v struct {
  148. Name string `path:"name"`
  149. Year string `path:"year"`
  150. }
  151. assert.Nil(t, httpx.ParsePath(r, &v))
  152. m[v.Name] = v.Year
  153. next.ServeHTTP(w, r)
  154. }
  155. },
  156. func(next http.HandlerFunc) http.HandlerFunc {
  157. return func(w http.ResponseWriter, r *http.Request) {
  158. var v struct {
  159. Name string `form:"nickname"`
  160. Zipcode string `form:"zipcode"`
  161. }
  162. assert.Nil(t, httpx.ParseForm(r, &v))
  163. assert.NotEmpty(t, m)
  164. m[v.Name] = v.Zipcode + v.Zipcode
  165. next.ServeHTTP(w, r)
  166. }
  167. },
  168. }, Route{
  169. Method: http.MethodGet,
  170. Path: "/first/:name/:year",
  171. Handler: handler,
  172. }, Route{
  173. Method: http.MethodGet,
  174. Path: "/second/:name/:year",
  175. Handler: handler,
  176. })
  177. urls := []string{
  178. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  179. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  180. }
  181. for _, route := range rs {
  182. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  183. }
  184. for _, url := range urls {
  185. r, err := http.NewRequest(http.MethodGet, url, nil)
  186. assert.Nil(t, err)
  187. rr := httptest.NewRecorder()
  188. rt.ServeHTTP(rr, r)
  189. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  190. }
  191. assert.EqualValues(t, map[string]string{
  192. "kevin": "2017",
  193. "wan": "2020",
  194. "whatever": "200000200000",
  195. }, m)
  196. }
  197. func TestWithPriority(t *testing.T) {
  198. var fr featuredRoutes
  199. WithPriority()(&fr)
  200. assert.True(t, fr.priority)
  201. }
  202. func TestWithTLSConfig(t *testing.T) {
  203. const configYaml = `
  204. Name: foo
  205. Port: 54321
  206. `
  207. var cnf RestConf
  208. assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
  209. testConfig := []uint16{
  210. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  211. }
  212. testCases := []struct {
  213. c RestConf
  214. opts []RunOption
  215. res *tls.Config
  216. }{
  217. {
  218. c: cnf,
  219. opts: []RunOption{WithTLSConfig(testConfig)},
  220. res: &tls.Config{CipherSuites: testConfig},
  221. },
  222. {
  223. c: cnf,
  224. opts: []RunOption{WithUnsignedCallback(nil)},
  225. res: nil,
  226. },
  227. }
  228. for _, testCase := range testCases {
  229. srv, err := NewServer(testCase.c, testCase.opts...)
  230. assert.Nil(t, err)
  231. assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
  232. }
  233. }