// handlers_test.go (3.7 KB)

  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. func TestCorsHandlerWithOrigins(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. origins []string
  12. reqOrigin string
  13. expect string
  14. }{
  15. {
  16. name: "allow all origins",
  17. expect: allOrigins,
  18. },
  19. {
  20. name: "allow one origin",
  21. origins: []string{"http://local"},
  22. reqOrigin: "http://local",
  23. expect: "http://local",
  24. },
  25. {
  26. name: "allow many origins",
  27. origins: []string{"http://local", "http://remote"},
  28. reqOrigin: "http://local",
  29. expect: "http://local",
  30. },
  31. {
  32. name: "allow all origins",
  33. reqOrigin: "http://local",
  34. expect: "*",
  35. },
  36. {
  37. name: "allow many origins with all mark",
  38. origins: []string{"http://local", "http://remote", "*"},
  39. reqOrigin: "http://another",
  40. expect: "http://another",
  41. },
  42. {
  43. name: "not allow origin",
  44. origins: []string{"http://local", "http://remote"},
  45. reqOrigin: "http://another",
  46. },
  47. }
  48. methods := []string{
  49. http.MethodOptions,
  50. http.MethodGet,
  51. http.MethodPost,
  52. }
  53. for _, test := range tests {
  54. for _, method := range methods {
  55. test := test
  56. t.Run(test.name+"-handler", func(t *testing.T) {
  57. r := httptest.NewRequest(method, "http://localhost", nil)
  58. r.Header.Set(originHeader, test.reqOrigin)
  59. w := httptest.NewRecorder()
  60. handler := NotAllowedHandler(nil, test.origins...)
  61. handler.ServeHTTP(w, r)
  62. if method == http.MethodOptions {
  63. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  64. } else {
  65. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  66. }
  67. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  68. })
  69. t.Run(test.name+"-handler-custom", func(t *testing.T) {
  70. r := httptest.NewRequest(method, "http://localhost", nil)
  71. r.Header.Set(originHeader, test.reqOrigin)
  72. w := httptest.NewRecorder()
  73. handler := NotAllowedHandler(func(w http.ResponseWriter) {
  74. w.Header().Set("foo", "bar")
  75. }, test.origins...)
  76. handler.ServeHTTP(w, r)
  77. if method == http.MethodOptions {
  78. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  79. } else {
  80. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  81. }
  82. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  83. assert.Equal(t, "bar", w.Header().Get("foo"))
  84. })
  85. }
  86. }
  87. for _, test := range tests {
  88. for _, method := range methods {
  89. test := test
  90. t.Run(test.name+"-middleware", func(t *testing.T) {
  91. r := httptest.NewRequest(method, "http://localhost", nil)
  92. r.Header.Set(originHeader, test.reqOrigin)
  93. w := httptest.NewRecorder()
  94. handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  95. w.WriteHeader(http.StatusOK)
  96. })
  97. handler.ServeHTTP(w, r)
  98. if method == http.MethodOptions {
  99. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  100. } else {
  101. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  102. }
  103. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  104. })
  105. t.Run(test.name+"-middleware-custom", func(t *testing.T) {
  106. r := httptest.NewRequest(method, "http://localhost", nil)
  107. r.Header.Set(originHeader, test.reqOrigin)
  108. w := httptest.NewRecorder()
  109. handler := Middleware(func(header http.Header) {
  110. header.Set("foo", "bar")
  111. }, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  112. w.WriteHeader(http.StatusOK)
  113. })
  114. handler.ServeHTTP(w, r)
  115. if method == http.MethodOptions {
  116. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  117. } else {
  118. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  119. }
  120. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  121. assert.Equal(t, "bar", w.Header().Get("foo"))
  122. })
  123. }
  124. }
  125. }