handlers_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package cors
  2. import (
  3. "bufio"
  4. "net"
  5. "net/http"
  6. "net/http/httptest"
  7. "testing"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestCorsHandlerWithOrigins(t *testing.T) {
  11. tests := []struct {
  12. name string
  13. origins []string
  14. reqOrigin string
  15. expect string
  16. }{
  17. {
  18. name: "allow all origins",
  19. expect: allOrigins,
  20. },
  21. {
  22. name: "allow one origin",
  23. origins: []string{"http://local"},
  24. reqOrigin: "http://local",
  25. expect: "http://local",
  26. },
  27. {
  28. name: "allow many origins",
  29. origins: []string{"http://local", "http://remote"},
  30. reqOrigin: "http://local",
  31. expect: "http://local",
  32. },
  33. {
  34. name: "allow all origins",
  35. reqOrigin: "http://local",
  36. expect: "*",
  37. },
  38. {
  39. name: "allow many origins with all mark",
  40. origins: []string{"http://local", "http://remote", "*"},
  41. reqOrigin: "http://another",
  42. expect: "http://another",
  43. },
  44. {
  45. name: "not allow origin",
  46. origins: []string{"http://local", "http://remote"},
  47. reqOrigin: "http://another",
  48. },
  49. }
  50. methods := []string{
  51. http.MethodOptions,
  52. http.MethodGet,
  53. http.MethodPost,
  54. }
  55. for _, test := range tests {
  56. for _, method := range methods {
  57. test := test
  58. t.Run(test.name+"-handler", func(t *testing.T) {
  59. r := httptest.NewRequest(method, "http://localhost", nil)
  60. r.Header.Set(originHeader, test.reqOrigin)
  61. w := httptest.NewRecorder()
  62. handler := NotAllowedHandler(nil, test.origins...)
  63. handler.ServeHTTP(w, r)
  64. if method == http.MethodOptions {
  65. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  66. } else {
  67. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  68. }
  69. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  70. })
  71. t.Run(test.name+"-handler-custom", func(t *testing.T) {
  72. r := httptest.NewRequest(method, "http://localhost", nil)
  73. r.Header.Set(originHeader, test.reqOrigin)
  74. w := httptest.NewRecorder()
  75. handler := NotAllowedHandler(func(w http.ResponseWriter) {
  76. w.Header().Set("foo", "bar")
  77. }, test.origins...)
  78. handler.ServeHTTP(w, r)
  79. if method == http.MethodOptions {
  80. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  81. } else {
  82. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  83. }
  84. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  85. assert.Equal(t, "bar", w.Header().Get("foo"))
  86. })
  87. }
  88. }
  89. for _, test := range tests {
  90. for _, method := range methods {
  91. test := test
  92. t.Run(test.name+"-middleware", func(t *testing.T) {
  93. r := httptest.NewRequest(method, "http://localhost", nil)
  94. r.Header.Set(originHeader, test.reqOrigin)
  95. w := httptest.NewRecorder()
  96. handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  97. w.WriteHeader(http.StatusOK)
  98. })
  99. handler.ServeHTTP(w, r)
  100. if method == http.MethodOptions {
  101. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  102. } else {
  103. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  104. }
  105. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  106. })
  107. t.Run(test.name+"-middleware-custom", func(t *testing.T) {
  108. r := httptest.NewRequest(method, "http://localhost", nil)
  109. r.Header.Set(originHeader, test.reqOrigin)
  110. w := httptest.NewRecorder()
  111. handler := Middleware(func(w http.ResponseWriter) {
  112. w.Header().Set("foo", "bar")
  113. }, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  114. w.WriteHeader(http.StatusOK)
  115. })
  116. handler.ServeHTTP(w, r)
  117. if method == http.MethodOptions {
  118. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  119. } else {
  120. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  121. }
  122. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  123. assert.Equal(t, "bar", w.Header().Get("foo"))
  124. })
  125. }
  126. }
  127. }
  128. func TestGuardedResponseWriter_Flush(t *testing.T) {
  129. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  130. handler := NotAllowedHandler(func(w http.ResponseWriter) {
  131. w.Header().Set("X-Test", "test")
  132. w.WriteHeader(http.StatusServiceUnavailable)
  133. _, err := w.Write([]byte("content"))
  134. assert.Nil(t, err)
  135. flusher, ok := w.(http.Flusher)
  136. assert.True(t, ok)
  137. flusher.Flush()
  138. }, "foo.com")
  139. resp := httptest.NewRecorder()
  140. handler.ServeHTTP(resp, req)
  141. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  142. assert.Equal(t, "test", resp.Header().Get("X-Test"))
  143. assert.Equal(t, "content", resp.Body.String())
  144. }
  145. func TestGuardedResponseWriter_Hijack(t *testing.T) {
  146. resp := httptest.NewRecorder()
  147. writer := &guardedResponseWriter{
  148. w: resp,
  149. }
  150. assert.NotPanics(t, func() {
  151. writer.Hijack()
  152. })
  153. writer = &guardedResponseWriter{
  154. w: mockedHijackable{resp},
  155. }
  156. assert.NotPanics(t, func() {
  157. writer.Hijack()
  158. })
  159. }
  // mockedHijackable embeds an *httptest.ResponseRecorder and adds a Hijack
// method, giving tests a ResponseWriter that also satisfies http.Hijacker.
type mockedHijackable struct {
	*httptest.ResponseRecorder
}

// Hijack satisfies http.Hijacker. It never touches the embedded recorder and
// returns a nil connection, a nil ReadWriter and a nil error.
func (mockedHijackable) Hijack() (conn net.Conn, rw *bufio.ReadWriter, err error) {
	// Every result is left at its zero value.
	return conn, rw, err
}