123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- package handler
- import (
- "bytes"
- "crypto/rand"
- "encoding/base64"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "testing/iotest"
- "github.com/stretchr/testify/assert"
- "github.com/wuntsong-org/go-zero-plus/core/codec"
- "github.com/wuntsong-org/go-zero-plus/core/logx/logtest"
- )
- const (
- reqText = "ping"
- respText = "pong"
- )
- var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
- func TestCryptionHandlerGet(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
- handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, err := w.Write([]byte(respText))
- w.Header().Set("X-Test", "test")
- assert.Nil(t, err)
- }))
- recorder := httptest.NewRecorder()
- handler.ServeHTTP(recorder, req)
- expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
- assert.Nil(t, err)
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, "test", recorder.Header().Get("X-Test"))
- assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
- }
- func TestCryptionHandlerGet_badKey(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
- handler := CryptionHandler(append(aesKey, aesKey...))(http.HandlerFunc(
- func(w http.ResponseWriter, r *http.Request) {
- _, err := w.Write([]byte(respText))
- w.Header().Set("X-Test", "test")
- assert.Nil(t, err)
- }))
- recorder := httptest.NewRecorder()
- handler.ServeHTTP(recorder, req)
- assert.Equal(t, http.StatusInternalServerError, recorder.Code)
- }
- func TestCryptionHandlerPost(t *testing.T) {
- var buf bytes.Buffer
- enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
- assert.Nil(t, err)
- buf.WriteString(base64.StdEncoding.EncodeToString(enc))
- req := httptest.NewRequest(http.MethodPost, "/any", &buf)
- handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- body, err := io.ReadAll(r.Body)
- assert.Nil(t, err)
- assert.Equal(t, reqText, string(body))
- w.Write([]byte(respText))
- }))
- recorder := httptest.NewRecorder()
- handler.ServeHTTP(recorder, req)
- expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
- assert.Nil(t, err)
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
- }
- func TestCryptionHandlerPostBadEncryption(t *testing.T) {
- var buf bytes.Buffer
- enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
- assert.Nil(t, err)
- buf.Write(enc)
- req := httptest.NewRequest(http.MethodPost, "/any", &buf)
- handler := CryptionHandler(aesKey)(nil)
- recorder := httptest.NewRecorder()
- handler.ServeHTTP(recorder, req)
- assert.Equal(t, http.StatusBadRequest, recorder.Code)
- }
- func TestCryptionHandlerWriteHeader(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
- handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusServiceUnavailable)
- }))
- recorder := httptest.NewRecorder()
- handler.ServeHTTP(recorder, req)
- assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
- }
- func TestCryptionHandlerFlush(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
- handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte(respText))
- flusher, ok := w.(http.Flusher)
- assert.True(t, ok)
- flusher.Flush()
- }))
- recorder := httptest.NewRecorder()
- handler.ServeHTTP(recorder, req)
- expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
- assert.Nil(t, err)
- assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
- }
- func TestCryptionHandler_Hijack(t *testing.T) {
- resp := httptest.NewRecorder()
- writer := newCryptionResponseWriter(resp)
- assert.NotPanics(t, func() {
- writer.Hijack()
- })
- writer = newCryptionResponseWriter(mockedHijackable{resp})
- assert.NotPanics(t, func() {
- writer.Hijack()
- })
- }
- func TestCryptionHandler_ContentTooLong(t *testing.T) {
- handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- }))
- svr := httptest.NewServer(handler)
- defer svr.Close()
- body := make([]byte, maxBytes+1)
- _, err := rand.Read(body)
- assert.NoError(t, err)
- req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body))
- assert.Nil(t, err)
- resp, err := http.DefaultClient.Do(req)
- assert.Nil(t, err)
- assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
- }
- func TestCryptionHandler_BadBody(t *testing.T) {
- req, err := http.NewRequest(http.MethodPost, "/foo", iotest.ErrReader(io.ErrUnexpectedEOF))
- assert.Nil(t, err)
- err = decryptBody(maxBytes, aesKey, req)
- assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
- }
- func TestCryptionHandler_BadKey(t *testing.T) {
- var buf bytes.Buffer
- enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
- assert.Nil(t, err)
- buf.WriteString(base64.StdEncoding.EncodeToString(enc))
- req := httptest.NewRequest(http.MethodPost, "/any", &buf)
- err = decryptBody(maxBytes, append(aesKey, aesKey...), req)
- assert.Error(t, err)
- }
- func TestCryptionResponseWriter_Flush(t *testing.T) {
- body := []byte("hello, world!")
- t.Run("half", func(t *testing.T) {
- recorder := httptest.NewRecorder()
- f := flushableResponseWriter{
- writer: &halfWriter{recorder},
- }
- w := newCryptionResponseWriter(f)
- _, err := w.Write(body)
- assert.NoError(t, err)
- w.flush(aesKey)
- b, err := io.ReadAll(recorder.Body)
- assert.NoError(t, err)
- expected, err := codec.EcbEncrypt(aesKey, body)
- assert.NoError(t, err)
- assert.True(t, strings.HasPrefix(base64.StdEncoding.EncodeToString(expected), string(b)))
- assert.True(t, len(string(b)) < len(base64.StdEncoding.EncodeToString(expected)))
- })
- t.Run("full", func(t *testing.T) {
- recorder := httptest.NewRecorder()
- f := flushableResponseWriter{
- writer: recorder,
- }
- w := newCryptionResponseWriter(f)
- _, err := w.Write(body)
- assert.NoError(t, err)
- w.flush(aesKey)
- b, err := io.ReadAll(recorder.Body)
- assert.NoError(t, err)
- expected, err := codec.EcbEncrypt(aesKey, body)
- assert.NoError(t, err)
- assert.Equal(t, base64.StdEncoding.EncodeToString(expected), string(b))
- })
- t.Run("bad writer", func(t *testing.T) {
- buf := logtest.NewCollector(t)
- f := flushableResponseWriter{
- writer: new(badWriter),
- }
- w := newCryptionResponseWriter(f)
- _, err := w.Write(body)
- assert.NoError(t, err)
- w.flush(aesKey)
- assert.True(t, strings.Contains(buf.Content(), io.ErrClosedPipe.Error()))
- })
- }
- type flushableResponseWriter struct {
- writer io.Writer
- }
- func (m flushableResponseWriter) Header() http.Header {
- panic("implement me")
- }
- func (m flushableResponseWriter) Write(p []byte) (int, error) {
- return m.writer.Write(p)
- }
- func (m flushableResponseWriter) WriteHeader(statusCode int) {
- panic("implement me")
- }
- type halfWriter struct {
- w io.Writer
- }
- func (t *halfWriter) Write(p []byte) (n int, err error) {
- n = len(p) >> 1
- return t.w.Write(p[0:n])
- }
- type badWriter struct {
- }
- func (b *badWriter) Write(p []byte) (n int, err error) {
- return 0, io.ErrClosedPipe
- }
|