responses_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. package httpx
  2. import (
  3. "errors"
  4. "net/http"
  5. "strings"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/zeromicro/go-zero/core/logx"
  9. )
  10. type message struct {
  11. Name string `json:"name"`
  12. }
  13. func init() {
  14. logx.Disable()
  15. }
  16. func TestError(t *testing.T) {
  17. const (
  18. body = "foo"
  19. wrappedBody = `"foo"`
  20. )
  21. tests := []struct {
  22. name string
  23. input string
  24. errorHandler func(error) (int, interface{})
  25. expectHasBody bool
  26. expectBody string
  27. expectCode int
  28. }{
  29. {
  30. name: "default error handler",
  31. input: body,
  32. expectHasBody: true,
  33. expectBody: body,
  34. expectCode: http.StatusBadRequest,
  35. },
  36. {
  37. name: "customized error handler return string",
  38. input: body,
  39. errorHandler: func(err error) (int, interface{}) {
  40. return http.StatusForbidden, err.Error()
  41. },
  42. expectHasBody: true,
  43. expectBody: wrappedBody,
  44. expectCode: http.StatusForbidden,
  45. },
  46. {
  47. name: "customized error handler return error",
  48. input: body,
  49. errorHandler: func(err error) (int, interface{}) {
  50. return http.StatusForbidden, err
  51. },
  52. expectHasBody: true,
  53. expectBody: body,
  54. expectCode: http.StatusForbidden,
  55. },
  56. {
  57. name: "customized error handler return nil",
  58. input: body,
  59. errorHandler: func(err error) (int, interface{}) {
  60. return http.StatusForbidden, nil
  61. },
  62. expectHasBody: false,
  63. expectBody: "",
  64. expectCode: http.StatusForbidden,
  65. },
  66. }
  67. for _, test := range tests {
  68. t.Run(test.name, func(t *testing.T) {
  69. w := tracedResponseWriter{
  70. headers: make(map[string][]string),
  71. }
  72. if test.errorHandler != nil {
  73. lock.RLock()
  74. prev := errorHandler
  75. lock.RUnlock()
  76. SetErrorHandler(test.errorHandler)
  77. defer func() {
  78. lock.Lock()
  79. errorHandler = prev
  80. lock.Unlock()
  81. }()
  82. }
  83. Error(&w, errors.New(test.input))
  84. assert.Equal(t, test.expectCode, w.code)
  85. assert.Equal(t, test.expectHasBody, w.hasBody)
  86. assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
  87. })
  88. }
  89. }
  90. func TestErrorWithHandler(t *testing.T) {
  91. w := tracedResponseWriter{
  92. headers: make(map[string][]string),
  93. }
  94. Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) {
  95. http.Error(w, err.Error(), 499)
  96. })
  97. assert.Equal(t, 499, w.code)
  98. assert.True(t, w.hasBody)
  99. assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
  100. }
  101. func TestOk(t *testing.T) {
  102. w := tracedResponseWriter{
  103. headers: make(map[string][]string),
  104. }
  105. Ok(&w)
  106. assert.Equal(t, http.StatusOK, w.code)
  107. }
  108. func TestOkJson(t *testing.T) {
  109. w := tracedResponseWriter{
  110. headers: make(map[string][]string),
  111. }
  112. msg := message{Name: "anyone"}
  113. OkJson(&w, msg)
  114. assert.Equal(t, http.StatusOK, w.code)
  115. assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
  116. }
  117. func TestWriteJsonTimeout(t *testing.T) {
  118. // only log it and ignore
  119. w := tracedResponseWriter{
  120. headers: make(map[string][]string),
  121. err: http.ErrHandlerTimeout,
  122. }
  123. msg := message{Name: "anyone"}
  124. WriteJson(&w, http.StatusOK, msg)
  125. assert.Equal(t, http.StatusOK, w.code)
  126. }
  127. func TestWriteJsonError(t *testing.T) {
  128. // only log it and ignore
  129. w := tracedResponseWriter{
  130. headers: make(map[string][]string),
  131. err: errors.New("foo"),
  132. }
  133. msg := message{Name: "anyone"}
  134. WriteJson(&w, http.StatusOK, msg)
  135. assert.Equal(t, http.StatusOK, w.code)
  136. }
  137. func TestWriteJsonLessWritten(t *testing.T) {
  138. w := tracedResponseWriter{
  139. headers: make(map[string][]string),
  140. lessWritten: true,
  141. }
  142. msg := message{Name: "anyone"}
  143. WriteJson(&w, http.StatusOK, msg)
  144. assert.Equal(t, http.StatusOK, w.code)
  145. }
  146. func TestWriteJsonMarshalFailed(t *testing.T) {
  147. w := tracedResponseWriter{
  148. headers: make(map[string][]string),
  149. }
  150. WriteJson(&w, http.StatusOK, map[string]interface{}{
  151. "Data": complex(0, 0),
  152. })
  153. assert.Equal(t, http.StatusInternalServerError, w.code)
  154. }
  155. type tracedResponseWriter struct {
  156. headers map[string][]string
  157. builder strings.Builder
  158. hasBody bool
  159. code int
  160. lessWritten bool
  161. wroteHeader bool
  162. err error
  163. }
  164. func (w *tracedResponseWriter) Header() http.Header {
  165. return w.headers
  166. }
  167. func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
  168. if w.err != nil {
  169. return 0, w.err
  170. }
  171. n, err = w.builder.Write(bytes)
  172. if w.lessWritten {
  173. n -= 1
  174. }
  175. w.hasBody = true
  176. return
  177. }
  178. func (w *tracedResponseWriter) WriteHeader(code int) {
  179. if w.wroteHeader {
  180. return
  181. }
  182. w.wroteHeader = true
  183. w.code = code
  184. }