handlers_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. func TestCorsHandlerWithOrigins(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. origins []string
  12. reqOrigin string
  13. expect string
  14. }{
  15. {
  16. name: "allow all origins",
  17. expect: allOrigins,
  18. },
  19. {
  20. name: "allow one origin",
  21. origins: []string{"http://local"},
  22. reqOrigin: "http://local",
  23. expect: "http://local",
  24. },
  25. {
  26. name: "allow many origins",
  27. origins: []string{"http://local", "http://remote"},
  28. reqOrigin: "http://local",
  29. expect: "http://local",
  30. },
  31. {
  32. name: "allow sub origins",
  33. origins: []string{"local", "remote"},
  34. reqOrigin: "sub.local",
  35. expect: "sub.local",
  36. },
  37. {
  38. name: "allow all origins",
  39. reqOrigin: "http://local",
  40. expect: "*",
  41. },
  42. {
  43. name: "allow many origins with all mark",
  44. origins: []string{"http://local", "http://remote", "*"},
  45. reqOrigin: "http://another",
  46. expect: "http://another",
  47. },
  48. {
  49. name: "not allow origin",
  50. origins: []string{"http://local", "http://remote"},
  51. reqOrigin: "http://another",
  52. },
  53. }
  54. methods := []string{
  55. http.MethodOptions,
  56. http.MethodGet,
  57. http.MethodPost,
  58. }
  59. for _, test := range tests {
  60. for _, method := range methods {
  61. test := test
  62. t.Run(test.name+"-handler", func(t *testing.T) {
  63. r := httptest.NewRequest(method, "http://localhost", nil)
  64. r.Header.Set(originHeader, test.reqOrigin)
  65. w := httptest.NewRecorder()
  66. handler := NotAllowedHandler(nil, test.origins...)
  67. handler.ServeHTTP(w, r)
  68. if method == http.MethodOptions {
  69. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  70. } else {
  71. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  72. }
  73. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  74. })
  75. t.Run(test.name+"-handler-custom", func(t *testing.T) {
  76. r := httptest.NewRequest(method, "http://localhost", nil)
  77. r.Header.Set(originHeader, test.reqOrigin)
  78. w := httptest.NewRecorder()
  79. handler := NotAllowedHandler(func(w http.ResponseWriter) {
  80. w.Header().Set("foo", "bar")
  81. }, test.origins...)
  82. handler.ServeHTTP(w, r)
  83. if method == http.MethodOptions {
  84. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  85. } else {
  86. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  87. }
  88. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  89. assert.Equal(t, "bar", w.Header().Get("foo"))
  90. })
  91. }
  92. }
  93. for _, test := range tests {
  94. for _, method := range methods {
  95. test := test
  96. t.Run(test.name+"-middleware", func(t *testing.T) {
  97. r := httptest.NewRequest(method, "http://localhost", nil)
  98. r.Header.Set(originHeader, test.reqOrigin)
  99. w := httptest.NewRecorder()
  100. handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  101. w.WriteHeader(http.StatusOK)
  102. })
  103. handler.ServeHTTP(w, r)
  104. if method == http.MethodOptions {
  105. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  106. } else {
  107. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  108. }
  109. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  110. })
  111. t.Run(test.name+"-middleware-custom", func(t *testing.T) {
  112. r := httptest.NewRequest(method, "http://localhost", nil)
  113. r.Header.Set(originHeader, test.reqOrigin)
  114. w := httptest.NewRecorder()
  115. handler := Middleware(func(header http.Header) {
  116. header.Set("foo", "bar")
  117. }, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  118. w.WriteHeader(http.StatusOK)
  119. })
  120. handler.ServeHTTP(w, r)
  121. if method == http.MethodOptions {
  122. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  123. } else {
  124. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  125. }
  126. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  127. assert.Equal(t, "bar", w.Header().Get("foo"))
  128. })
  129. }
  130. }
  131. }