patrouter_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package httprouter
  2. import (
  3. "net/http"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. )
  7. type mockedResponseWriter struct {
  8. code int
  9. }
  10. func (m *mockedResponseWriter) Header() http.Header {
  11. return http.Header{}
  12. }
  13. func (m *mockedResponseWriter) Write(p []byte) (int, error) {
  14. return len(p), nil
  15. }
  16. func (m *mockedResponseWriter) WriteHeader(code int) {
  17. m.code = code
  18. }
  19. func TestPatRouterHandleErrors(t *testing.T) {
  20. tests := []struct {
  21. method string
  22. path string
  23. err error
  24. }{
  25. {"FAKE", "", ErrInvalidMethod},
  26. {"GET", "", ErrInvalidPath},
  27. }
  28. for _, test := range tests {
  29. t.Run(test.method, func(t *testing.T) {
  30. router := NewPatRouter()
  31. err := router.Handle(test.method, test.path, nil)
  32. assert.Error(t, ErrInvalidMethod, err)
  33. })
  34. }
  35. }
  36. func TestPatRouterNotFound(t *testing.T) {
  37. var notFound bool
  38. router := NewPatRouter()
  39. router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  40. notFound = true
  41. }))
  42. router.Handle(http.MethodGet, "/a/b", nil)
  43. r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
  44. w := new(mockedResponseWriter)
  45. router.ServeHTTP(w, r)
  46. assert.True(t, notFound)
  47. }
  48. func TestPatRouter(t *testing.T) {
  49. tests := []struct {
  50. method string
  51. path string
  52. expect bool
  53. code int
  54. err error
  55. }{
  56. // we don't explicitly set status code, framework will do it.
  57. {http.MethodGet, "/a/b", true, 0, nil},
  58. {http.MethodGet, "/a/b/", true, 0, nil},
  59. {http.MethodGet, "/a/b?a=b", true, 0, nil},
  60. {http.MethodGet, "/a/b/?a=b", true, 0, nil},
  61. {http.MethodGet, "/a/b/c?a=b", true, 0, nil},
  62. {http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
  63. }
  64. for _, test := range tests {
  65. t.Run(test.method+":"+test.path, func(t *testing.T) {
  66. routed := false
  67. router := NewPatRouter()
  68. err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  69. routed = true
  70. assert.Equal(t, 1, len(Vars(r)))
  71. }))
  72. assert.Nil(t, err)
  73. err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  74. routed = true
  75. assert.Nil(t, Vars(r))
  76. }))
  77. assert.Nil(t, err)
  78. err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  79. routed = true
  80. }))
  81. assert.Nil(t, err)
  82. w := new(mockedResponseWriter)
  83. r, _ := http.NewRequest(test.method, test.path, nil)
  84. router.ServeHTTP(w, r)
  85. assert.Equal(t, test.expect, routed)
  86. assert.Equal(t, test.code, w.code)
  87. if test.code == 0 {
  88. r, _ = http.NewRequest(http.MethodPut, test.path, nil)
  89. router.ServeHTTP(w, r)
  90. assert.Equal(t, http.StatusMethodNotAllowed, w.code)
  91. }
  92. })
  93. }
  94. }
  95. func BenchmarkPatRouter(b *testing.B) {
  96. b.ReportAllocs()
  97. router := NewPatRouter()
  98. router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  99. }))
  100. w := &mockedResponseWriter{}
  101. r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
  102. for i := 0; i < b.N; i++ {
  103. router.ServeHTTP(w, r)
  104. }
  105. }