|
@@ -0,0 +1,33 @@
|
|
|
+package security
|
|
|
+
|
|
|
+import (
|
|
|
+ "net/http"
|
|
|
+ "net/http/httptest"
|
|
|
+ "testing"
|
|
|
+
|
|
|
+ "github.com/stretchr/testify/assert"
|
|
|
+)
|
|
|
+
|
|
|
+func TestWithCodeResponseWriter(t *testing.T) {
|
|
|
+ req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
|
|
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ cw := &WithCodeResponseWriter{Writer: w}
|
|
|
+
|
|
|
+ cw.Header().Set("X-Test", "test")
|
|
|
+ cw.WriteHeader(http.StatusServiceUnavailable)
|
|
|
+ assert.Equal(t, cw.Code, http.StatusServiceUnavailable)
|
|
|
+
|
|
|
+ _, err := cw.Write([]byte("content"))
|
|
|
+ assert.Nil(t, err)
|
|
|
+
|
|
|
+ flusher, ok := http.ResponseWriter(cw).(http.Flusher)
|
|
|
+ assert.Equal(t, ok, true)
|
|
|
+ flusher.Flush()
|
|
|
+ })
|
|
|
+
|
|
|
+ resp := httptest.NewRecorder()
|
|
|
+ handler.ServeHTTP(resp, req)
|
|
|
+ assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
|
|
+ assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
|
|
+ assert.Equal(t, "content", resp.Body.String())
|
|
|
+}
|