|
@@ -7,11 +7,13 @@ import (
|
|
"io"
|
|
"io"
|
|
"net/http"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httptest"
|
|
|
|
+ "strings"
|
|
"testing"
|
|
"testing"
|
|
"testing/iotest"
|
|
"testing/iotest"
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/zeromicro/go-zero/core/codec"
|
|
"github.com/zeromicro/go-zero/core/codec"
|
|
|
|
+ "github.com/zeromicro/go-zero/core/logx/logtest"
|
|
)
|
|
)
|
|
|
|
|
|
const (
|
|
const (
|
|
@@ -160,3 +162,84 @@ func TestCryptionHandler_BadKey(t *testing.T) {
|
|
err = decryptBody(maxBytes, append(aesKey, aesKey...), req)
|
|
err = decryptBody(maxBytes, append(aesKey, aesKey...), req)
|
|
assert.Error(t, err)
|
|
assert.Error(t, err)
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+func TestCryptionResponseWriter_Flush(t *testing.T) {
|
|
|
|
+ body := []byte("hello, world!")
|
|
|
|
+
|
|
|
|
+ t.Run("half", func(t *testing.T) {
|
|
|
|
+ recorder := httptest.NewRecorder()
|
|
|
|
+ f := flushableResponseWriter{
|
|
|
|
+ writer: &halfWriter{recorder},
|
|
|
|
+ }
|
|
|
|
+ w := newCryptionResponseWriter(f)
|
|
|
|
+ _, err := w.Write(body)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ w.flush(aesKey)
|
|
|
|
+ b, err := io.ReadAll(recorder.Body)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ expected, err := codec.EcbEncrypt(aesKey, body)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ assert.True(t, strings.HasPrefix(base64.StdEncoding.EncodeToString(expected), string(b)))
|
|
|
|
+ assert.True(t, len(string(b)) < len(base64.StdEncoding.EncodeToString(expected)))
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ t.Run("full", func(t *testing.T) {
|
|
|
|
+ recorder := httptest.NewRecorder()
|
|
|
|
+ f := flushableResponseWriter{
|
|
|
|
+ writer: recorder,
|
|
|
|
+ }
|
|
|
|
+ w := newCryptionResponseWriter(f)
|
|
|
|
+ _, err := w.Write(body)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ w.flush(aesKey)
|
|
|
|
+ b, err := io.ReadAll(recorder.Body)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ expected, err := codec.EcbEncrypt(aesKey, body)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ assert.Equal(t, base64.StdEncoding.EncodeToString(expected), string(b))
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ t.Run("bad writer", func(t *testing.T) {
|
|
|
|
+ buf := logtest.NewCollector(t)
|
|
|
|
+ f := flushableResponseWriter{
|
|
|
|
+ writer: new(badWriter),
|
|
|
|
+ }
|
|
|
|
+ w := newCryptionResponseWriter(f)
|
|
|
|
+ _, err := w.Write(body)
|
|
|
|
+ assert.NoError(t, err)
|
|
|
|
+ w.flush(aesKey)
|
|
|
|
+ assert.True(t, strings.Contains(buf.Content(), io.ErrClosedPipe.Error()))
|
|
|
|
+ })
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+type flushableResponseWriter struct {
|
|
|
|
+ writer io.Writer
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m flushableResponseWriter) Header() http.Header {
|
|
|
|
+ panic("implement me")
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m flushableResponseWriter) Write(p []byte) (int, error) {
|
|
|
|
+ return m.writer.Write(p)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m flushableResponseWriter) WriteHeader(statusCode int) {
|
|
|
|
+ panic("implement me")
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+type halfWriter struct {
|
|
|
|
+ w io.Writer
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (t *halfWriter) Write(p []byte) (n int, err error) {
|
|
|
|
+ n = len(p) >> 1
|
|
|
|
+ return t.w.Write(p[0:n])
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+type badWriter struct {
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (b *badWriter) Write(p []byte) (n int, err error) {
|
|
|
|
+ return 0, io.ErrClosedPipe
|
|
|
|
+}
|