responses_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package httpx
  2. import (
  3. "errors"
  4. "net/http"
  5. "strings"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/zeromicro/go-zero/core/logx"
  9. "google.golang.org/grpc/codes"
  10. "google.golang.org/grpc/status"
  11. )
  // message is a minimal payload used by the JSON-writing tests; its single
  // field is serialized under the key "name".
  type message struct {
  	Name string `json:"name"` // JSON key: "name"
  }
  15. func init() {
  16. logx.Disable()
  17. }
  18. func TestError(t *testing.T) {
  19. const (
  20. body = "foo"
  21. wrappedBody = `"foo"`
  22. )
  23. tests := []struct {
  24. name string
  25. input string
  26. errorHandler func(error) (int, interface{})
  27. expectHasBody bool
  28. expectBody string
  29. expectCode int
  30. }{
  31. {
  32. name: "default error handler",
  33. input: body,
  34. expectHasBody: true,
  35. expectBody: body,
  36. expectCode: http.StatusBadRequest,
  37. },
  38. {
  39. name: "customized error handler return string",
  40. input: body,
  41. errorHandler: func(err error) (int, interface{}) {
  42. return http.StatusForbidden, err.Error()
  43. },
  44. expectHasBody: true,
  45. expectBody: wrappedBody,
  46. expectCode: http.StatusForbidden,
  47. },
  48. {
  49. name: "customized error handler return error",
  50. input: body,
  51. errorHandler: func(err error) (int, interface{}) {
  52. return http.StatusForbidden, err
  53. },
  54. expectHasBody: true,
  55. expectBody: body,
  56. expectCode: http.StatusForbidden,
  57. },
  58. {
  59. name: "customized error handler return nil",
  60. input: body,
  61. errorHandler: func(err error) (int, interface{}) {
  62. return http.StatusForbidden, nil
  63. },
  64. expectHasBody: false,
  65. expectBody: "",
  66. expectCode: http.StatusForbidden,
  67. },
  68. }
  69. for _, test := range tests {
  70. t.Run(test.name, func(t *testing.T) {
  71. w := tracedResponseWriter{
  72. headers: make(map[string][]string),
  73. }
  74. if test.errorHandler != nil {
  75. lock.RLock()
  76. prev := errorHandler
  77. lock.RUnlock()
  78. SetErrorHandler(test.errorHandler)
  79. defer func() {
  80. lock.Lock()
  81. errorHandler = prev
  82. lock.Unlock()
  83. }()
  84. }
  85. Error(&w, errors.New(test.input))
  86. assert.Equal(t, test.expectCode, w.code)
  87. assert.Equal(t, test.expectHasBody, w.hasBody)
  88. assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
  89. })
  90. }
  91. }
  92. func TestErrorWithGrpcError(t *testing.T) {
  93. w := tracedResponseWriter{
  94. headers: make(map[string][]string),
  95. }
  96. Error(&w, status.Error(codes.Unavailable, "foo"))
  97. assert.Equal(t, http.StatusServiceUnavailable, w.code)
  98. assert.True(t, w.hasBody)
  99. assert.True(t, strings.Contains(w.builder.String(), "foo"))
  100. }
  101. func TestErrorWithHandler(t *testing.T) {
  102. w := tracedResponseWriter{
  103. headers: make(map[string][]string),
  104. }
  105. Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) {
  106. http.Error(w, err.Error(), 499)
  107. })
  108. assert.Equal(t, 499, w.code)
  109. assert.True(t, w.hasBody)
  110. assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
  111. }
  112. func TestOk(t *testing.T) {
  113. w := tracedResponseWriter{
  114. headers: make(map[string][]string),
  115. }
  116. Ok(&w)
  117. assert.Equal(t, http.StatusOK, w.code)
  118. }
  119. func TestOkJson(t *testing.T) {
  120. w := tracedResponseWriter{
  121. headers: make(map[string][]string),
  122. }
  123. msg := message{Name: "anyone"}
  124. OkJson(&w, msg)
  125. assert.Equal(t, http.StatusOK, w.code)
  126. assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
  127. }
  128. func TestWriteJsonTimeout(t *testing.T) {
  129. // only log it and ignore
  130. w := tracedResponseWriter{
  131. headers: make(map[string][]string),
  132. err: http.ErrHandlerTimeout,
  133. }
  134. msg := message{Name: "anyone"}
  135. WriteJson(&w, http.StatusOK, msg)
  136. assert.Equal(t, http.StatusOK, w.code)
  137. }
  138. func TestWriteJsonError(t *testing.T) {
  139. // only log it and ignore
  140. w := tracedResponseWriter{
  141. headers: make(map[string][]string),
  142. err: errors.New("foo"),
  143. }
  144. msg := message{Name: "anyone"}
  145. WriteJson(&w, http.StatusOK, msg)
  146. assert.Equal(t, http.StatusOK, w.code)
  147. }
  148. func TestWriteJsonLessWritten(t *testing.T) {
  149. w := tracedResponseWriter{
  150. headers: make(map[string][]string),
  151. lessWritten: true,
  152. }
  153. msg := message{Name: "anyone"}
  154. WriteJson(&w, http.StatusOK, msg)
  155. assert.Equal(t, http.StatusOK, w.code)
  156. }
  157. func TestWriteJsonMarshalFailed(t *testing.T) {
  158. w := tracedResponseWriter{
  159. headers: make(map[string][]string),
  160. }
  161. WriteJson(&w, http.StatusOK, map[string]interface{}{
  162. "Data": complex(0, 0),
  163. })
  164. assert.Equal(t, http.StatusInternalServerError, w.code)
  165. }
  // tracedResponseWriter is an in-memory http.ResponseWriter that records what
  // the code under test writes. The err and lessWritten fields let a test
  // inject write failures and short writes.
  type tracedResponseWriter struct {
  	headers     map[string][]string // map returned by Header
  	builder     strings.Builder     // accumulates the bytes of every successful Write
  	hasBody     bool                // set by any Write that does not fail with err
  	code        int                 // status from the first WriteHeader call
  	lessWritten bool                // when true, Write reports one byte fewer than it stored
  	wroteHeader bool                // makes later WriteHeader calls no-ops
  	err         error               // when non-nil, Write returns it and stores nothing
  }

  // Compile-time check that the fake satisfies http.ResponseWriter.
  var _ http.ResponseWriter = (*tracedResponseWriter)(nil)

  // Header returns the writer's header map.
  func (w *tracedResponseWriter) Header() http.Header {
  	return w.headers
  }

  // Write stores p in the captured body and marks hasBody.
  //
  // If err is set, Write returns (0, err) without storing anything. If
  // lessWritten is set, the reported count is one less than the number of
  // bytes stored, to simulate a short write. The count never drops below zero,
  // because io.Writer requires 0 <= n <= len(p).
  func (w *tracedResponseWriter) Write(p []byte) (int, error) {
  	if w.err != nil {
  		return 0, w.err
  	}

  	n, err := w.builder.Write(p)
  	if w.lessWritten && n > 0 {
  		n--
  	}
  	w.hasBody = true
  	return n, err
  }

  // WriteHeader records code on the first call only; later calls are ignored,
  // mirroring the fact that a status line can only be sent once.
  func (w *tracedResponseWriter) WriteHeader(code int) {
  	if w.wroteHeader {
  		return
  	}

  	w.wroteHeader = true
  	w.code = code
  }