cryptionhandler_test.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package handler
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "io"
  6. "math/rand"
  7. "net/http"
  8. "net/http/httptest"
  9. "testing"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/zeromicro/go-zero/core/codec"
  12. )
  13. const (
  14. reqText = "ping"
  15. respText = "pong"
  16. )
  17. var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
  18. func TestCryptionHandlerGet(t *testing.T) {
  19. req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
  20. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  21. _, err := w.Write([]byte(respText))
  22. w.Header().Set("X-Test", "test")
  23. assert.Nil(t, err)
  24. }))
  25. recorder := httptest.NewRecorder()
  26. handler.ServeHTTP(recorder, req)
  27. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  28. assert.Nil(t, err)
  29. assert.Equal(t, http.StatusOK, recorder.Code)
  30. assert.Equal(t, "test", recorder.Header().Get("X-Test"))
  31. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  32. }
  33. func TestCryptionHandlerPost(t *testing.T) {
  34. var buf bytes.Buffer
  35. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  36. assert.Nil(t, err)
  37. buf.WriteString(base64.StdEncoding.EncodeToString(enc))
  38. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  39. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  40. body, err := io.ReadAll(r.Body)
  41. assert.Nil(t, err)
  42. assert.Equal(t, reqText, string(body))
  43. w.Write([]byte(respText))
  44. }))
  45. recorder := httptest.NewRecorder()
  46. handler.ServeHTTP(recorder, req)
  47. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  48. assert.Nil(t, err)
  49. assert.Equal(t, http.StatusOK, recorder.Code)
  50. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  51. }
  52. func TestCryptionHandlerPostBadEncryption(t *testing.T) {
  53. var buf bytes.Buffer
  54. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  55. assert.Nil(t, err)
  56. buf.Write(enc)
  57. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  58. handler := CryptionHandler(aesKey)(nil)
  59. recorder := httptest.NewRecorder()
  60. handler.ServeHTTP(recorder, req)
  61. assert.Equal(t, http.StatusBadRequest, recorder.Code)
  62. }
  63. func TestCryptionHandlerWriteHeader(t *testing.T) {
  64. req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
  65. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  66. w.WriteHeader(http.StatusServiceUnavailable)
  67. }))
  68. recorder := httptest.NewRecorder()
  69. handler.ServeHTTP(recorder, req)
  70. assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
  71. }
  72. func TestCryptionHandlerFlush(t *testing.T) {
  73. req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
  74. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  75. w.Write([]byte(respText))
  76. flusher, ok := w.(http.Flusher)
  77. assert.True(t, ok)
  78. flusher.Flush()
  79. }))
  80. recorder := httptest.NewRecorder()
  81. handler.ServeHTTP(recorder, req)
  82. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  83. assert.Nil(t, err)
  84. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  85. }
  86. func TestCryptionHandler_Hijack(t *testing.T) {
  87. resp := httptest.NewRecorder()
  88. writer := newCryptionResponseWriter(resp)
  89. assert.NotPanics(t, func() {
  90. writer.Hijack()
  91. })
  92. writer = newCryptionResponseWriter(mockedHijackable{resp})
  93. assert.NotPanics(t, func() {
  94. writer.Hijack()
  95. })
  96. }
  97. func TestCryptionHandler_ContentTooLong(t *testing.T) {
  98. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  99. }))
  100. svr := httptest.NewServer(handler)
  101. defer svr.Close()
  102. body := make([]byte, maxBytes+1)
  103. rand.Read(body)
  104. req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body))
  105. assert.Nil(t, err)
  106. resp, err := http.DefaultClient.Do(req)
  107. assert.Nil(t, err)
  108. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  109. }