Explorar o código

chore: add more tests (#3578)

Kevin Wan hai 1 ano
pai
achega
18d66a795d
Modificáronse 1 ficheiros con 83 adicións e 0 borrados
  1. 83 0
      rest/handler/cryptionhandler_test.go

+ 83 - 0
rest/handler/cryptionhandler_test.go

@@ -7,11 +7,13 @@ import (
 	"io"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"testing"
 	"testing/iotest"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/codec"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 const (
@@ -160,3 +162,84 @@ func TestCryptionHandler_BadKey(t *testing.T) {
 	err = decryptBody(maxBytes, append(aesKey, aesKey...), req)
 	assert.Error(t, err)
 }
+
+func TestCryptionResponseWriter_Flush(t *testing.T) {
+	body := []byte("hello, world!")
+
+	t.Run("half", func(t *testing.T) {
+		recorder := httptest.NewRecorder()
+		f := flushableResponseWriter{
+			writer: &halfWriter{recorder},
+		}
+		w := newCryptionResponseWriter(f)
+		_, err := w.Write(body)
+		assert.NoError(t, err)
+		w.flush(aesKey)
+		b, err := io.ReadAll(recorder.Body)
+		assert.NoError(t, err)
+		expected, err := codec.EcbEncrypt(aesKey, body)
+		assert.NoError(t, err)
+		assert.True(t, strings.HasPrefix(base64.StdEncoding.EncodeToString(expected), string(b)))
+		assert.True(t, len(string(b)) < len(base64.StdEncoding.EncodeToString(expected)))
+	})
+
+	t.Run("full", func(t *testing.T) {
+		recorder := httptest.NewRecorder()
+		f := flushableResponseWriter{
+			writer: recorder,
+		}
+		w := newCryptionResponseWriter(f)
+		_, err := w.Write(body)
+		assert.NoError(t, err)
+		w.flush(aesKey)
+		b, err := io.ReadAll(recorder.Body)
+		assert.NoError(t, err)
+		expected, err := codec.EcbEncrypt(aesKey, body)
+		assert.NoError(t, err)
+		assert.Equal(t, base64.StdEncoding.EncodeToString(expected), string(b))
+	})
+
+	t.Run("bad writer", func(t *testing.T) {
+		buf := logtest.NewCollector(t)
+		f := flushableResponseWriter{
+			writer: new(badWriter),
+		}
+		w := newCryptionResponseWriter(f)
+		_, err := w.Write(body)
+		assert.NoError(t, err)
+		w.flush(aesKey)
+		assert.True(t, strings.Contains(buf.Content(), io.ErrClosedPipe.Error()))
+	})
+}
+
+type flushableResponseWriter struct {
+	writer io.Writer
+}
+
+func (m flushableResponseWriter) Header() http.Header {
+	panic("implement me")
+}
+
+func (m flushableResponseWriter) Write(p []byte) (int, error) {
+	return m.writer.Write(p)
+}
+
+func (m flushableResponseWriter) WriteHeader(statusCode int) {
+	panic("implement me")
+}
+
+type halfWriter struct {
+	w io.Writer
+}
+
+func (t *halfWriter) Write(p []byte) (n int, err error) {
+	n = len(p) >> 1
+	return t.w.Write(p[0:n])
+}
+
+type badWriter struct {
+}
+
+func (b *badWriter) Write(p []byte) (n int, err error) {
+	return 0, io.ErrClosedPipe
+}