handlers_test.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. func TestCorsHandlerWithOrigins(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. origins []string
  12. reqOrigin string
  13. expect string
  14. }{
  15. {
  16. name: "allow all origins",
  17. expect: allOrigins,
  18. },
  19. {
  20. name: "allow one origin",
  21. origins: []string{"http://local"},
  22. reqOrigin: "http://local",
  23. expect: "http://local",
  24. },
  25. {
  26. name: "allow many origins",
  27. origins: []string{"http://local", "http://remote"},
  28. reqOrigin: "http://local",
  29. expect: "http://local",
  30. },
  31. {
  32. name: "allow sub origins",
  33. origins: []string{"local", "remote"},
  34. reqOrigin: "sub.local",
  35. expect: "sub.local",
  36. },
  37. {
  38. name: "allow all origins",
  39. reqOrigin: "http://local",
  40. expect: "*",
  41. },
  42. {
  43. name: "allow many origins with all mark",
  44. origins: []string{"http://local", "http://remote", "*"},
  45. reqOrigin: "http://another",
  46. expect: "http://another",
  47. },
  48. {
  49. name: "not allow origin",
  50. origins: []string{"http://local", "http://remote"},
  51. reqOrigin: "http://another",
  52. },
  53. {
  54. name: "not safe origin",
  55. origins: []string{"safe.com"},
  56. reqOrigin: "not-safe.com",
  57. },
  58. }
  59. methods := []string{
  60. http.MethodOptions,
  61. http.MethodGet,
  62. http.MethodPost,
  63. }
  64. for _, test := range tests {
  65. for _, method := range methods {
  66. test := test
  67. t.Run(test.name+"-handler", func(t *testing.T) {
  68. r := httptest.NewRequest(method, "http://localhost", http.NoBody)
  69. r.Header.Set(originHeader, test.reqOrigin)
  70. w := httptest.NewRecorder()
  71. handler := NotAllowedHandler(nil, test.origins...)
  72. handler.ServeHTTP(w, r)
  73. if method == http.MethodOptions {
  74. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  75. } else {
  76. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  77. }
  78. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  79. })
  80. t.Run(test.name+"-handler-custom", func(t *testing.T) {
  81. r := httptest.NewRequest(method, "http://localhost", http.NoBody)
  82. r.Header.Set(originHeader, test.reqOrigin)
  83. w := httptest.NewRecorder()
  84. handler := NotAllowedHandler(func(w http.ResponseWriter) {
  85. w.Header().Set("foo", "bar")
  86. }, test.origins...)
  87. handler.ServeHTTP(w, r)
  88. if method == http.MethodOptions {
  89. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  90. } else {
  91. assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
  92. }
  93. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  94. assert.Equal(t, "bar", w.Header().Get("foo"))
  95. })
  96. }
  97. }
  98. for _, test := range tests {
  99. for _, method := range methods {
  100. test := test
  101. t.Run(test.name+"-middleware", func(t *testing.T) {
  102. r := httptest.NewRequest(method, "http://localhost", http.NoBody)
  103. r.Header.Set(originHeader, test.reqOrigin)
  104. w := httptest.NewRecorder()
  105. handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  106. w.WriteHeader(http.StatusOK)
  107. })
  108. handler.ServeHTTP(w, r)
  109. if method == http.MethodOptions {
  110. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  111. } else {
  112. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  113. }
  114. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  115. })
  116. t.Run(test.name+"-middleware-custom", func(t *testing.T) {
  117. r := httptest.NewRequest(method, "http://localhost", http.NoBody)
  118. r.Header.Set(originHeader, test.reqOrigin)
  119. w := httptest.NewRecorder()
  120. handler := Middleware(func(header http.Header) {
  121. header.Set("foo", "bar")
  122. }, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
  123. w.WriteHeader(http.StatusOK)
  124. })
  125. handler.ServeHTTP(w, r)
  126. if method == http.MethodOptions {
  127. assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
  128. } else {
  129. assert.Equal(t, http.StatusOK, w.Result().StatusCode)
  130. }
  131. assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
  132. assert.Equal(t, "bar", w.Header().Get("foo"))
  133. })
  134. }
  135. }
  136. }