cryptionhandler_test.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package handler
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "encoding/base64"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "testing"
  10. "testing/iotest"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/zeromicro/go-zero/core/codec"
  13. )
  14. const (
  15. reqText = "ping"
  16. respText = "pong"
  17. )
  18. var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
  19. func TestCryptionHandlerGet(t *testing.T) {
  20. req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
  21. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  22. _, err := w.Write([]byte(respText))
  23. w.Header().Set("X-Test", "test")
  24. assert.Nil(t, err)
  25. }))
  26. recorder := httptest.NewRecorder()
  27. handler.ServeHTTP(recorder, req)
  28. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  29. assert.Nil(t, err)
  30. assert.Equal(t, http.StatusOK, recorder.Code)
  31. assert.Equal(t, "test", recorder.Header().Get("X-Test"))
  32. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  33. }
  34. func TestCryptionHandlerGet_badKey(t *testing.T) {
  35. req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
  36. handler := CryptionHandler(append(aesKey, aesKey...))(http.HandlerFunc(
  37. func(w http.ResponseWriter, r *http.Request) {
  38. _, err := w.Write([]byte(respText))
  39. w.Header().Set("X-Test", "test")
  40. assert.Nil(t, err)
  41. }))
  42. recorder := httptest.NewRecorder()
  43. handler.ServeHTTP(recorder, req)
  44. assert.Equal(t, http.StatusInternalServerError, recorder.Code)
  45. }
  46. func TestCryptionHandlerPost(t *testing.T) {
  47. var buf bytes.Buffer
  48. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  49. assert.Nil(t, err)
  50. buf.WriteString(base64.StdEncoding.EncodeToString(enc))
  51. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  52. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  53. body, err := io.ReadAll(r.Body)
  54. assert.Nil(t, err)
  55. assert.Equal(t, reqText, string(body))
  56. w.Write([]byte(respText))
  57. }))
  58. recorder := httptest.NewRecorder()
  59. handler.ServeHTTP(recorder, req)
  60. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  61. assert.Nil(t, err)
  62. assert.Equal(t, http.StatusOK, recorder.Code)
  63. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  64. }
  65. func TestCryptionHandlerPostBadEncryption(t *testing.T) {
  66. var buf bytes.Buffer
  67. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  68. assert.Nil(t, err)
  69. buf.Write(enc)
  70. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  71. handler := CryptionHandler(aesKey)(nil)
  72. recorder := httptest.NewRecorder()
  73. handler.ServeHTTP(recorder, req)
  74. assert.Equal(t, http.StatusBadRequest, recorder.Code)
  75. }
  76. func TestCryptionHandlerWriteHeader(t *testing.T) {
  77. req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
  78. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  79. w.WriteHeader(http.StatusServiceUnavailable)
  80. }))
  81. recorder := httptest.NewRecorder()
  82. handler.ServeHTTP(recorder, req)
  83. assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
  84. }
  85. func TestCryptionHandlerFlush(t *testing.T) {
  86. req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
  87. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  88. w.Write([]byte(respText))
  89. flusher, ok := w.(http.Flusher)
  90. assert.True(t, ok)
  91. flusher.Flush()
  92. }))
  93. recorder := httptest.NewRecorder()
  94. handler.ServeHTTP(recorder, req)
  95. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  96. assert.Nil(t, err)
  97. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  98. }
  99. func TestCryptionHandler_Hijack(t *testing.T) {
  100. resp := httptest.NewRecorder()
  101. writer := newCryptionResponseWriter(resp)
  102. assert.NotPanics(t, func() {
  103. writer.Hijack()
  104. })
  105. writer = newCryptionResponseWriter(mockedHijackable{resp})
  106. assert.NotPanics(t, func() {
  107. writer.Hijack()
  108. })
  109. }
  110. func TestCryptionHandler_ContentTooLong(t *testing.T) {
  111. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  112. }))
  113. svr := httptest.NewServer(handler)
  114. defer svr.Close()
  115. body := make([]byte, maxBytes+1)
  116. _, err := rand.Read(body)
  117. assert.NoError(t, err)
  118. req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body))
  119. assert.Nil(t, err)
  120. resp, err := http.DefaultClient.Do(req)
  121. assert.Nil(t, err)
  122. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  123. }
  124. func TestCryptionHandler_BadBody(t *testing.T) {
  125. req, err := http.NewRequest(http.MethodPost, "/foo", iotest.ErrReader(io.ErrUnexpectedEOF))
  126. assert.Nil(t, err)
  127. err = decryptBody(maxBytes, aesKey, req)
  128. assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
  129. }
  130. func TestCryptionHandler_BadKey(t *testing.T) {
  131. var buf bytes.Buffer
  132. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  133. assert.Nil(t, err)
  134. buf.WriteString(base64.StdEncoding.EncodeToString(enc))
  135. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  136. err = decryptBody(maxBytes, append(aesKey, aesKey...), req)
  137. assert.Error(t, err)
  138. }