|
@@ -3,25 +3,40 @@ package rest
|
|
import (
|
|
import (
|
|
"net/http"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httptest"
|
|
- "strings"
|
|
|
|
"testing"
|
|
"testing"
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
)
|
|
|
|
|
|
-func TestCorsHandler(t *testing.T) {
|
|
|
|
- w := httptest.NewRecorder()
|
|
|
|
- handler := CorsHandler()
|
|
|
|
- handler.ServeHTTP(w, nil)
|
|
|
|
- assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
|
|
|
- assert.Equal(t, allOrigin, w.Header().Get(allowOrigin))
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
func TestCorsHandlerWithOrigins(t *testing.T) {
|
|
func TestCorsHandlerWithOrigins(t *testing.T) {
|
|
- origins := []string{"local", "remote"}
|
|
|
|
- w := httptest.NewRecorder()
|
|
|
|
- handler := CorsHandler(origins...)
|
|
|
|
- handler.ServeHTTP(w, nil)
|
|
|
|
- assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
|
|
|
- assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin))
|
|
|
|
|
|
+ tests := []struct {
|
|
|
|
+ name string
|
|
|
|
+ origins []string
|
|
|
|
+ expect string
|
|
|
|
+ }{
|
|
|
|
+ {
|
|
|
|
+ name: "allow all origins",
|
|
|
|
+ expect: allOrigins,
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "allow one origin",
|
|
|
|
+ origins: []string{"local"},
|
|
|
|
+ expect: "local",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "allow many origins",
|
|
|
|
+ origins: []string{"local", "remote"},
|
|
|
|
+ expect: "local",
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for _, test := range tests {
|
|
|
|
+ t.Run(test.name, func(t *testing.T) {
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
|
+ handler := CorsHandler(test.origins...)
|
|
|
|
+ handler.ServeHTTP(w, nil)
|
|
|
|
+ assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
|
|
|
+ assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
|
|
|
+ })
|
|
|
|
+ }
|
|
}
|
|
}
|