cryptionhandler_test.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package handler
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "io"
  6. "log"
  7. "math/rand"
  8. "net/http"
  9. "net/http/httptest"
  10. "testing"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/zeromicro/go-zero/core/codec"
  13. )
  14. const (
  15. reqText = "ping"
  16. respText = "pong"
  17. )
  18. var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
  19. func init() {
  20. log.SetOutput(io.Discard)
  21. }
  22. func TestCryptionHandlerGet(t *testing.T) {
  23. req := httptest.NewRequest(http.MethodGet, "/any", nil)
  24. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  25. _, err := w.Write([]byte(respText))
  26. w.Header().Set("X-Test", "test")
  27. assert.Nil(t, err)
  28. }))
  29. recorder := httptest.NewRecorder()
  30. handler.ServeHTTP(recorder, req)
  31. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  32. assert.Nil(t, err)
  33. assert.Equal(t, http.StatusOK, recorder.Code)
  34. assert.Equal(t, "test", recorder.Header().Get("X-Test"))
  35. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  36. }
  37. func TestCryptionHandlerPost(t *testing.T) {
  38. var buf bytes.Buffer
  39. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  40. assert.Nil(t, err)
  41. buf.WriteString(base64.StdEncoding.EncodeToString(enc))
  42. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  43. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  44. body, err := io.ReadAll(r.Body)
  45. assert.Nil(t, err)
  46. assert.Equal(t, reqText, string(body))
  47. w.Write([]byte(respText))
  48. }))
  49. recorder := httptest.NewRecorder()
  50. handler.ServeHTTP(recorder, req)
  51. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  52. assert.Nil(t, err)
  53. assert.Equal(t, http.StatusOK, recorder.Code)
  54. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  55. }
  56. func TestCryptionHandlerPostBadEncryption(t *testing.T) {
  57. var buf bytes.Buffer
  58. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  59. assert.Nil(t, err)
  60. buf.Write(enc)
  61. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  62. handler := CryptionHandler(aesKey)(nil)
  63. recorder := httptest.NewRecorder()
  64. handler.ServeHTTP(recorder, req)
  65. assert.Equal(t, http.StatusBadRequest, recorder.Code)
  66. }
  67. func TestCryptionHandlerWriteHeader(t *testing.T) {
  68. req := httptest.NewRequest(http.MethodGet, "/any", nil)
  69. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  70. w.WriteHeader(http.StatusServiceUnavailable)
  71. }))
  72. recorder := httptest.NewRecorder()
  73. handler.ServeHTTP(recorder, req)
  74. assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
  75. }
  76. func TestCryptionHandlerFlush(t *testing.T) {
  77. req := httptest.NewRequest(http.MethodGet, "/any", nil)
  78. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  79. w.Write([]byte(respText))
  80. flusher, ok := w.(http.Flusher)
  81. assert.True(t, ok)
  82. flusher.Flush()
  83. }))
  84. recorder := httptest.NewRecorder()
  85. handler.ServeHTTP(recorder, req)
  86. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  87. assert.Nil(t, err)
  88. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  89. }
  90. func TestCryptionHandler_Hijack(t *testing.T) {
  91. resp := httptest.NewRecorder()
  92. writer := newCryptionResponseWriter(resp)
  93. assert.NotPanics(t, func() {
  94. writer.Hijack()
  95. })
  96. writer = newCryptionResponseWriter(mockedHijackable{resp})
  97. assert.NotPanics(t, func() {
  98. writer.Hijack()
  99. })
  100. }
  101. func TestCryptionHandler_ContentTooLong(t *testing.T) {
  102. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  103. }))
  104. svr := httptest.NewServer(handler)
  105. defer svr.Close()
  106. body := make([]byte, maxBytes+1)
  107. rand.Read(body)
  108. req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body))
  109. assert.Nil(t, err)
  110. resp, err := http.DefaultClient.Do(req)
  111. assert.Nil(t, err)
  112. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  113. }