responses_test.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. package httpx
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "strings"
  7. "testing"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/zeromicro/go-zero/core/logx"
  10. "google.golang.org/grpc/codes"
  11. "google.golang.org/grpc/status"
  12. )
  13. type message struct {
  14. Name string `json:"name"`
  15. }
  16. func init() {
  17. logx.Disable()
  18. }
  19. func TestError(t *testing.T) {
  20. const (
  21. body = "foo"
  22. wrappedBody = `"foo"`
  23. )
  24. tests := []struct {
  25. name string
  26. input string
  27. errorHandler func(error) (int, interface{})
  28. expectHasBody bool
  29. expectBody string
  30. expectCode int
  31. }{
  32. {
  33. name: "default error handler",
  34. input: body,
  35. expectHasBody: true,
  36. expectBody: body,
  37. expectCode: http.StatusBadRequest,
  38. },
  39. {
  40. name: "customized error handler return string",
  41. input: body,
  42. errorHandler: func(err error) (int, interface{}) {
  43. return http.StatusForbidden, err.Error()
  44. },
  45. expectHasBody: true,
  46. expectBody: wrappedBody,
  47. expectCode: http.StatusForbidden,
  48. },
  49. {
  50. name: "customized error handler return error",
  51. input: body,
  52. errorHandler: func(err error) (int, interface{}) {
  53. return http.StatusForbidden, err
  54. },
  55. expectHasBody: true,
  56. expectBody: body,
  57. expectCode: http.StatusForbidden,
  58. },
  59. {
  60. name: "customized error handler return nil",
  61. input: body,
  62. errorHandler: func(err error) (int, interface{}) {
  63. return http.StatusForbidden, nil
  64. },
  65. expectHasBody: false,
  66. expectBody: "",
  67. expectCode: http.StatusForbidden,
  68. },
  69. }
  70. for _, test := range tests {
  71. t.Run(test.name, func(t *testing.T) {
  72. w := tracedResponseWriter{
  73. headers: make(map[string][]string),
  74. }
  75. if test.errorHandler != nil {
  76. lock.RLock()
  77. prev := errorHandler
  78. lock.RUnlock()
  79. SetErrorHandler(test.errorHandler)
  80. defer func() {
  81. lock.Lock()
  82. errorHandler = prev
  83. lock.Unlock()
  84. }()
  85. }
  86. Error(&w, errors.New(test.input))
  87. assert.Equal(t, test.expectCode, w.code)
  88. assert.Equal(t, test.expectHasBody, w.hasBody)
  89. assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
  90. })
  91. }
  92. }
  93. func TestErrorWithGrpcError(t *testing.T) {
  94. w := tracedResponseWriter{
  95. headers: make(map[string][]string),
  96. }
  97. Error(&w, status.Error(codes.Unavailable, "foo"))
  98. assert.Equal(t, http.StatusServiceUnavailable, w.code)
  99. assert.True(t, w.hasBody)
  100. assert.True(t, strings.Contains(w.builder.String(), "foo"))
  101. }
  102. func TestErrorWithHandler(t *testing.T) {
  103. w := tracedResponseWriter{
  104. headers: make(map[string][]string),
  105. }
  106. Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) {
  107. http.Error(w, err.Error(), 499)
  108. })
  109. assert.Equal(t, 499, w.code)
  110. assert.True(t, w.hasBody)
  111. assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
  112. }
  113. func TestOk(t *testing.T) {
  114. w := tracedResponseWriter{
  115. headers: make(map[string][]string),
  116. }
  117. Ok(&w)
  118. assert.Equal(t, http.StatusOK, w.code)
  119. }
  120. func TestOkJson(t *testing.T) {
  121. w := tracedResponseWriter{
  122. headers: make(map[string][]string),
  123. }
  124. msg := message{Name: "anyone"}
  125. OkJson(&w, msg)
  126. assert.Equal(t, http.StatusOK, w.code)
  127. assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
  128. }
  129. func TestWriteJsonTimeout(t *testing.T) {
  130. // only log it and ignore
  131. w := tracedResponseWriter{
  132. headers: make(map[string][]string),
  133. err: http.ErrHandlerTimeout,
  134. }
  135. msg := message{Name: "anyone"}
  136. WriteJson(&w, http.StatusOK, msg)
  137. assert.Equal(t, http.StatusOK, w.code)
  138. }
  139. func TestWriteJsonError(t *testing.T) {
  140. // only log it and ignore
  141. w := tracedResponseWriter{
  142. headers: make(map[string][]string),
  143. err: errors.New("foo"),
  144. }
  145. msg := message{Name: "anyone"}
  146. WriteJson(&w, http.StatusOK, msg)
  147. assert.Equal(t, http.StatusOK, w.code)
  148. }
  149. func TestWriteJsonLessWritten(t *testing.T) {
  150. w := tracedResponseWriter{
  151. headers: make(map[string][]string),
  152. lessWritten: true,
  153. }
  154. msg := message{Name: "anyone"}
  155. WriteJson(&w, http.StatusOK, msg)
  156. assert.Equal(t, http.StatusOK, w.code)
  157. }
  158. func TestWriteJsonMarshalFailed(t *testing.T) {
  159. w := tracedResponseWriter{
  160. headers: make(map[string][]string),
  161. }
  162. WriteJson(&w, http.StatusOK, map[string]interface{}{
  163. "Data": complex(0, 0),
  164. })
  165. assert.Equal(t, http.StatusInternalServerError, w.code)
  166. }
  167. type tracedResponseWriter struct {
  168. headers map[string][]string
  169. builder strings.Builder
  170. hasBody bool
  171. code int
  172. lessWritten bool
  173. wroteHeader bool
  174. err error
  175. }
  176. func (w *tracedResponseWriter) Header() http.Header {
  177. return w.headers
  178. }
  179. func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
  180. if w.err != nil {
  181. return 0, w.err
  182. }
  183. n, err = w.builder.Write(bytes)
  184. if w.lessWritten {
  185. n--
  186. }
  187. w.hasBody = true
  188. return
  189. }
  190. func (w *tracedResponseWriter) WriteHeader(code int) {
  191. if w.wroteHeader {
  192. return
  193. }
  194. w.wroteHeader = true
  195. w.code = code
  196. }
  197. func TestErrorCtx(t *testing.T) {
  198. const (
  199. body = "foo"
  200. wrappedBody = `"foo"`
  201. )
  202. tests := []struct {
  203. name string
  204. input string
  205. errorHandlerCtx func(context.Context, error) (int, interface{})
  206. expectHasBody bool
  207. expectBody string
  208. expectCode int
  209. }{
  210. {
  211. name: "default error handler",
  212. input: body,
  213. expectHasBody: true,
  214. expectBody: body,
  215. expectCode: http.StatusBadRequest,
  216. },
  217. {
  218. name: "customized error handler return string",
  219. input: body,
  220. errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
  221. return http.StatusForbidden, err.Error()
  222. },
  223. expectHasBody: true,
  224. expectBody: wrappedBody,
  225. expectCode: http.StatusForbidden,
  226. },
  227. {
  228. name: "customized error handler return error",
  229. input: body,
  230. errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
  231. return http.StatusForbidden, err
  232. },
  233. expectHasBody: true,
  234. expectBody: body,
  235. expectCode: http.StatusForbidden,
  236. },
  237. {
  238. name: "customized error handler return nil",
  239. input: body,
  240. errorHandlerCtx: func(context.Context, error) (int, interface{}) {
  241. return http.StatusForbidden, nil
  242. },
  243. expectHasBody: false,
  244. expectBody: "",
  245. expectCode: http.StatusForbidden,
  246. },
  247. }
  248. for _, test := range tests {
  249. t.Run(test.name, func(t *testing.T) {
  250. w := tracedResponseWriter{
  251. headers: make(map[string][]string),
  252. }
  253. if test.errorHandlerCtx != nil {
  254. lock.RLock()
  255. prev := errorHandlerCtx
  256. lock.RUnlock()
  257. SetErrorHandlerCtx(test.errorHandlerCtx)
  258. defer func() {
  259. lock.Lock()
  260. test.errorHandlerCtx = prev
  261. lock.Unlock()
  262. }()
  263. }
  264. ErrorCtx(context.Background(), &w, errors.New(test.input))
  265. assert.Equal(t, test.expectCode, w.code)
  266. assert.Equal(t, test.expectHasBody, w.hasBody)
  267. assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
  268. })
  269. }
  270. //The current handler is a global event,Set default values to avoid impacting subsequent unit tests
  271. SetErrorHandlerCtx(nil)
  272. }
  273. func TestErrorWithGrpcErrorCtx(t *testing.T) {
  274. w := tracedResponseWriter{
  275. headers: make(map[string][]string),
  276. }
  277. ErrorCtx(context.Background(), &w, status.Error(codes.Unavailable, "foo"))
  278. assert.Equal(t, http.StatusServiceUnavailable, w.code)
  279. assert.True(t, w.hasBody)
  280. assert.True(t, strings.Contains(w.builder.String(), "foo"))
  281. }
  282. func TestErrorWithHandlerCtx(t *testing.T) {
  283. w := tracedResponseWriter{
  284. headers: make(map[string][]string),
  285. }
  286. ErrorCtx(context.Background(), &w, errors.New("foo"), func(w http.ResponseWriter, err error) {
  287. http.Error(w, err.Error(), 499)
  288. })
  289. assert.Equal(t, 499, w.code)
  290. assert.True(t, w.hasBody)
  291. assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
  292. }
  293. func TestWriteJsonCtxMarshalFailed(t *testing.T) {
  294. w := tracedResponseWriter{
  295. headers: make(map[string][]string),
  296. }
  297. WriteJsonCtx(context.Background(), &w, http.StatusOK, map[string]interface{}{
  298. "Data": complex(0, 0),
  299. })
  300. assert.Equal(t, http.StatusInternalServerError, w.code)
  301. }