server_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package rest
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "testing"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/tal-tech/go-zero/rest/httpx"
  10. "github.com/tal-tech/go-zero/rest/router"
  11. )
  12. func TestWithMiddleware(t *testing.T) {
  13. m := make(map[string]string)
  14. router := router.NewPatRouter()
  15. handler := func(w http.ResponseWriter, r *http.Request) {
  16. var v struct {
  17. Nickname string `form:"nickname"`
  18. Zipcode int64 `form:"zipcode"`
  19. }
  20. err := httpx.Parse(r, &v)
  21. assert.Nil(t, err)
  22. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  23. assert.Nil(t, err)
  24. }
  25. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  26. return func(w http.ResponseWriter, r *http.Request) {
  27. var v struct {
  28. Name string `path:"name"`
  29. Year string `path:"year"`
  30. }
  31. assert.Nil(t, httpx.ParsePath(r, &v))
  32. m[v.Name] = v.Year
  33. next.ServeHTTP(w, r)
  34. }
  35. }, Route{
  36. Method: http.MethodGet,
  37. Path: "/first/:name/:year",
  38. Handler: handler,
  39. }, Route{
  40. Method: http.MethodGet,
  41. Path: "/second/:name/:year",
  42. Handler: handler,
  43. })
  44. urls := []string{
  45. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  46. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  47. }
  48. for _, route := range rs {
  49. assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
  50. }
  51. for _, url := range urls {
  52. r, err := http.NewRequest(http.MethodGet, url, nil)
  53. assert.Nil(t, err)
  54. rr := httptest.NewRecorder()
  55. router.ServeHTTP(rr, r)
  56. assert.Equal(t, "whatever:200000", rr.Body.String())
  57. }
  58. assert.EqualValues(t, map[string]string{
  59. "kevin": "2017",
  60. "wan": "2020",
  61. }, m)
  62. }
  63. func TestMultiMiddleware(t *testing.T) {
  64. m := make(map[string]string)
  65. router := router.NewPatRouter()
  66. handler := func(w http.ResponseWriter, r *http.Request) {
  67. var v struct {
  68. Nickname string `form:"nickname"`
  69. Zipcode int64 `form:"zipcode"`
  70. }
  71. err := httpx.Parse(r, &v)
  72. assert.Nil(t, err)
  73. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  74. assert.Nil(t, err)
  75. }
  76. rs := WithMiddlewares([]Middleware{
  77. func(next http.HandlerFunc) http.HandlerFunc {
  78. return func(w http.ResponseWriter, r *http.Request) {
  79. var v struct {
  80. Name string `path:"name"`
  81. Year string `path:"year"`
  82. }
  83. assert.Nil(t, httpx.ParsePath(r, &v))
  84. m[v.Name] = v.Year
  85. next.ServeHTTP(w, r)
  86. }
  87. },
  88. func(next http.HandlerFunc) http.HandlerFunc {
  89. return func(w http.ResponseWriter, r *http.Request) {
  90. var v struct {
  91. Name string `form:"nickname"`
  92. Zipcode string `form:"zipcode"`
  93. }
  94. assert.Nil(t, httpx.ParseForm(r, &v))
  95. assert.NotEmpty(t, m)
  96. m[v.Name] = v.Zipcode + v.Zipcode
  97. next.ServeHTTP(w, r)
  98. }
  99. },
  100. }, Route{
  101. Method: http.MethodGet,
  102. Path: "/first/:name/:year",
  103. Handler: handler,
  104. }, Route{
  105. Method: http.MethodGet,
  106. Path: "/second/:name/:year",
  107. Handler: handler,
  108. })
  109. urls := []string{
  110. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  111. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  112. }
  113. for _, route := range rs {
  114. assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
  115. }
  116. for _, url := range urls {
  117. r, err := http.NewRequest(http.MethodGet, url, nil)
  118. assert.Nil(t, err)
  119. rr := httptest.NewRecorder()
  120. router.ServeHTTP(rr, r)
  121. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  122. }
  123. assert.EqualValues(t, map[string]string{
  124. "kevin": "2017",
  125. "wan": "2020",
  126. "whatever": "200000200000",
  127. }, m)
  128. }