patrouter_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package router
  2. import (
  3. "net/http"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. "zero/ngin/internal/context"
  7. )
  8. type mockedResponseWriter struct {
  9. code int
  10. }
  11. func (m *mockedResponseWriter) Header() http.Header {
  12. return http.Header{}
  13. }
  14. func (m *mockedResponseWriter) Write(p []byte) (int, error) {
  15. return len(p), nil
  16. }
  17. func (m *mockedResponseWriter) WriteHeader(code int) {
  18. m.code = code
  19. }
  20. func TestPatRouterHandleErrors(t *testing.T) {
  21. tests := []struct {
  22. method string
  23. path string
  24. err error
  25. }{
  26. {"FAKE", "", ErrInvalidMethod},
  27. {"GET", "", ErrInvalidPath},
  28. }
  29. for _, test := range tests {
  30. t.Run(test.method, func(t *testing.T) {
  31. router := NewPatRouter()
  32. err := router.Handle(test.method, test.path, nil)
  33. assert.Error(t, ErrInvalidMethod, err)
  34. })
  35. }
  36. }
  37. func TestPatRouterNotFound(t *testing.T) {
  38. var notFound bool
  39. router := NewPatRouter()
  40. router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  41. notFound = true
  42. }))
  43. router.Handle(http.MethodGet, "/a/b", nil)
  44. r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
  45. w := new(mockedResponseWriter)
  46. router.ServeHTTP(w, r)
  47. assert.True(t, notFound)
  48. }
  49. func TestPatRouter(t *testing.T) {
  50. tests := []struct {
  51. method string
  52. path string
  53. expect bool
  54. code int
  55. err error
  56. }{
  57. // we don't explicitly set status code, framework will do it.
  58. {http.MethodGet, "/a/b", true, 0, nil},
  59. {http.MethodGet, "/a/b/", true, 0, nil},
  60. {http.MethodGet, "/a/b?a=b", true, 0, nil},
  61. {http.MethodGet, "/a/b/?a=b", true, 0, nil},
  62. {http.MethodGet, "/a/b/c?a=b", true, 0, nil},
  63. {http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
  64. }
  65. for _, test := range tests {
  66. t.Run(test.method+":"+test.path, func(t *testing.T) {
  67. routed := false
  68. router := NewPatRouter()
  69. err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  70. routed = true
  71. assert.Equal(t, 1, len(context.Vars(r)))
  72. }))
  73. assert.Nil(t, err)
  74. err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  75. routed = true
  76. assert.Nil(t, context.Vars(r))
  77. }))
  78. assert.Nil(t, err)
  79. err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  80. routed = true
  81. }))
  82. assert.Nil(t, err)
  83. w := new(mockedResponseWriter)
  84. r, _ := http.NewRequest(test.method, test.path, nil)
  85. router.ServeHTTP(w, r)
  86. assert.Equal(t, test.expect, routed)
  87. assert.Equal(t, test.code, w.code)
  88. if test.code == 0 {
  89. r, _ = http.NewRequest(http.MethodPut, test.path, nil)
  90. router.ServeHTTP(w, r)
  91. assert.Equal(t, http.StatusMethodNotAllowed, w.code)
  92. }
  93. })
  94. }
  95. }
  96. func BenchmarkPatRouter(b *testing.B) {
  97. b.ReportAllocs()
  98. router := NewPatRouter()
  99. router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  100. }))
  101. w := &mockedResponseWriter{}
  102. r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
  103. for i := 0; i < b.N; i++ {
  104. router.ServeHTTP(w, r)
  105. }
  106. }