Ver código fonte

The ResponseWriters defined in rest.handler add Flush interface. (#318)

jichangyun 4 anos atrás
pai
commit
0bd2a0656c

+ 6 - 0
rest/handler/authhandler.go

@@ -138,3 +138,9 @@ func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
 	grw.wroteHeader = true
 	grw.writer.WriteHeader(statusCode)
 }
+
+func (grw *guardedResponseWriter) Flush() {
+	if flusher, ok := grw.writer.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}

+ 4 - 0
rest/handler/authhandler_test.go

@@ -41,6 +41,10 @@ func TestAuthHandler(t *testing.T) {
 			w.Header().Set("X-Test", "test")
 			_, err := w.Write([]byte("content"))
 			assert.Nil(t, err)
+
+			flusher, ok := w.(http.Flusher)
+			assert.Equal(t, ok, true)
+			flusher.Flush()
 		}))
 
 	resp := httptest.NewRecorder()

+ 6 - 0
rest/handler/cryptionhandler.go

@@ -95,6 +95,12 @@ func (w *cryptionResponseWriter) WriteHeader(statusCode int) {
 	w.ResponseWriter.WriteHeader(statusCode)
 }
 
+func (w *cryptionResponseWriter) Flush() {
+	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
 func (w *cryptionResponseWriter) flush(key []byte) {
 	if w.buf.Len() == 0 {
 		return

+ 17 - 0
rest/handler/cryptionhandler_test.go

@@ -87,3 +87,20 @@ func TestCryptionHandlerWriteHeader(t *testing.T) {
 	handler.ServeHTTP(recorder, req)
 	assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
 }
+
+func TestCryptionHandlerFlush(t *testing.T) {
+	req := httptest.NewRequest(http.MethodGet, "/any", nil)
+	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(respText))
+
+		flusher, ok := w.(http.Flusher)
+		assert.Equal(t, ok, true)
+		flusher.Flush()
+	}))
+	recorder := httptest.NewRecorder()
+	handler.ServeHTTP(recorder, req)
+
+	expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
+	assert.Nil(t, err)
+	assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
+}

+ 12 - 0
rest/handler/loghandler.go

@@ -38,6 +38,12 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
 	w.code = code
 }
 
+func (w *LoggedResponseWriter) Flush() {
+	if flusher, ok := w.w.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
 func LogHandler(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		timer := utils.NewElapsedTimer()
@@ -81,6 +87,12 @@ func (w *DetailLoggedResponseWriter) WriteHeader(code int) {
 	w.writer.WriteHeader(code)
 }
 
+func (w *DetailLoggedResponseWriter) Flush() {
+	if flusher, ok := http.ResponseWriter(w.writer).(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
 func DetailedLogHandler(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		timer := utils.NewElapsedTimer()

+ 4 - 0
rest/handler/loghandler_test.go

@@ -30,6 +30,10 @@ func TestLogHandler(t *testing.T) {
 			w.WriteHeader(http.StatusServiceUnavailable)
 			_, err := w.Write([]byte("content"))
 			assert.Nil(t, err)
+
+			flusher, ok := w.(http.Flusher)
+			assert.Equal(t, ok, true)
+			flusher.Flush()
 		}))
 
 		resp := httptest.NewRecorder()

+ 6 - 0
rest/internal/security/withcoderesponsewriter.go

@@ -19,3 +19,9 @@ func (w *WithCodeResponseWriter) WriteHeader(code int) {
 	w.Writer.WriteHeader(code)
 	w.Code = code
 }
+
+func (w *WithCodeResponseWriter) Flush() {
+	if flusher, ok := w.Writer.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}

+ 33 - 0
rest/internal/security/withcoderesponsewriter_test.go

@@ -0,0 +1,33 @@
+package security
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestWithCodeResponseWriter(t *testing.T) {
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		cw := &WithCodeResponseWriter{Writer: w}
+
+		cw.Header().Set("X-Test", "test")
+		cw.WriteHeader(http.StatusServiceUnavailable)
+		assert.Equal(t, cw.Code, http.StatusServiceUnavailable)
+
+		_, err := cw.Write([]byte("content"))
+		assert.Nil(t, err)
+
+		flusher, ok := http.ResponseWriter(cw).(http.Flusher)
+		assert.Equal(t, ok, true)
+		flusher.Flush()
+	})
+
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
+	assert.Equal(t, "test", resp.Header().Get("X-Test"))
+	assert.Equal(t, "content", resp.Body.String())
+}