responses_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. package httpx
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "strings"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/zeromicro/go-zero/core/logx"
  11. "google.golang.org/grpc/codes"
  12. "google.golang.org/grpc/status"
  13. )
  14. type message struct {
  15. Name string `json:"name"`
  16. }
  17. func init() {
  18. logx.Disable()
  19. }
  20. func TestError(t *testing.T) {
  21. const (
  22. body = "foo"
  23. wrappedBody = `"foo"`
  24. )
  25. tests := []struct {
  26. name string
  27. input string
  28. errorHandler func(error) (int, any)
  29. expectHasBody bool
  30. expectBody string
  31. expectCode int
  32. }{
  33. {
  34. name: "default error handler",
  35. input: body,
  36. expectHasBody: true,
  37. expectBody: body,
  38. expectCode: http.StatusBadRequest,
  39. },
  40. {
  41. name: "customized error handler return string",
  42. input: body,
  43. errorHandler: func(err error) (int, any) {
  44. return http.StatusForbidden, err.Error()
  45. },
  46. expectHasBody: true,
  47. expectBody: wrappedBody,
  48. expectCode: http.StatusForbidden,
  49. },
  50. {
  51. name: "customized error handler return error",
  52. input: body,
  53. errorHandler: func(err error) (int, any) {
  54. return http.StatusForbidden, err
  55. },
  56. expectHasBody: true,
  57. expectBody: body,
  58. expectCode: http.StatusForbidden,
  59. },
  60. {
  61. name: "customized error handler return nil",
  62. input: body,
  63. errorHandler: func(err error) (int, any) {
  64. return http.StatusForbidden, nil
  65. },
  66. expectHasBody: false,
  67. expectBody: "",
  68. expectCode: http.StatusForbidden,
  69. },
  70. }
  71. for _, test := range tests {
  72. t.Run(test.name, func(t *testing.T) {
  73. w := tracedResponseWriter{
  74. headers: make(map[string][]string),
  75. }
  76. if test.errorHandler != nil {
  77. errorLock.RLock()
  78. prev := errorHandler
  79. errorLock.RUnlock()
  80. SetErrorHandler(test.errorHandler)
  81. defer func() {
  82. errorLock.Lock()
  83. errorHandler = prev
  84. errorLock.Unlock()
  85. }()
  86. }
  87. Error(&w, errors.New(test.input))
  88. assert.Equal(t, test.expectCode, w.code)
  89. assert.Equal(t, test.expectHasBody, w.hasBody)
  90. assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
  91. })
  92. }
  93. }
  94. func TestErrorWithGrpcError(t *testing.T) {
  95. w := tracedResponseWriter{
  96. headers: make(map[string][]string),
  97. }
  98. Error(&w, status.Error(codes.Unavailable, "foo"))
  99. assert.Equal(t, http.StatusServiceUnavailable, w.code)
  100. assert.True(t, w.hasBody)
  101. assert.True(t, strings.Contains(w.builder.String(), "foo"))
  102. }
  103. func TestErrorWithHandler(t *testing.T) {
  104. w := tracedResponseWriter{
  105. headers: make(map[string][]string),
  106. }
  107. Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) {
  108. http.Error(w, err.Error(), 499)
  109. })
  110. assert.Equal(t, 499, w.code)
  111. assert.True(t, w.hasBody)
  112. assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
  113. }
  114. func TestOk(t *testing.T) {
  115. w := tracedResponseWriter{
  116. headers: make(map[string][]string),
  117. }
  118. Ok(&w)
  119. assert.Equal(t, http.StatusOK, w.code)
  120. }
  121. func TestOkJson(t *testing.T) {
  122. t.Run("no handler", func(t *testing.T) {
  123. w := tracedResponseWriter{
  124. headers: make(map[string][]string),
  125. }
  126. msg := message{Name: "anyone"}
  127. OkJson(&w, msg)
  128. assert.Equal(t, http.StatusOK, w.code)
  129. assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
  130. })
  131. t.Run("with handler", func(t *testing.T) {
  132. respLock.RLock()
  133. prev := respHandler
  134. respLock.RUnlock()
  135. t.Cleanup(func() {
  136. respLock.Lock()
  137. respHandler = prev
  138. respLock.Unlock()
  139. })
  140. SetResponseHandler(func(_ context.Context, v interface{}) any {
  141. return fmt.Sprintf("hello %s", v.(message).Name)
  142. })
  143. w := tracedResponseWriter{
  144. headers: make(map[string][]string),
  145. }
  146. msg := message{Name: "anyone"}
  147. OkJson(&w, msg)
  148. assert.Equal(t, http.StatusOK, w.code)
  149. assert.Equal(t, `"hello anyone"`, w.builder.String())
  150. })
  151. }
  152. func TestOkJsonCtx(t *testing.T) {
  153. t.Run("no handler", func(t *testing.T) {
  154. w := tracedResponseWriter{
  155. headers: make(map[string][]string),
  156. }
  157. msg := message{Name: "anyone"}
  158. OkJsonCtx(context.Background(), &w, msg)
  159. assert.Equal(t, http.StatusOK, w.code)
  160. assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
  161. })
  162. t.Run("with handler", func(t *testing.T) {
  163. respLock.RLock()
  164. prev := respHandler
  165. respLock.RUnlock()
  166. t.Cleanup(func() {
  167. respLock.Lock()
  168. respHandler = prev
  169. respLock.Unlock()
  170. })
  171. SetResponseHandler(func(_ context.Context, v interface{}) any {
  172. return fmt.Sprintf("hello %s", v.(message).Name)
  173. })
  174. w := tracedResponseWriter{
  175. headers: make(map[string][]string),
  176. }
  177. msg := message{Name: "anyone"}
  178. OkJsonCtx(context.Background(), &w, msg)
  179. assert.Equal(t, http.StatusOK, w.code)
  180. assert.Equal(t, `"hello anyone"`, w.builder.String())
  181. })
  182. }
  183. func TestWriteJsonTimeout(t *testing.T) {
  184. // only log it and ignore
  185. w := tracedResponseWriter{
  186. headers: make(map[string][]string),
  187. err: http.ErrHandlerTimeout,
  188. }
  189. msg := message{Name: "anyone"}
  190. WriteJson(&w, http.StatusOK, msg)
  191. assert.Equal(t, http.StatusOK, w.code)
  192. }
  193. func TestWriteJsonError(t *testing.T) {
  194. // only log it and ignore
  195. w := tracedResponseWriter{
  196. headers: make(map[string][]string),
  197. err: errors.New("foo"),
  198. }
  199. msg := message{Name: "anyone"}
  200. WriteJson(&w, http.StatusOK, msg)
  201. assert.Equal(t, http.StatusOK, w.code)
  202. }
  203. func TestWriteJsonLessWritten(t *testing.T) {
  204. w := tracedResponseWriter{
  205. headers: make(map[string][]string),
  206. lessWritten: true,
  207. }
  208. msg := message{Name: "anyone"}
  209. WriteJson(&w, http.StatusOK, msg)
  210. assert.Equal(t, http.StatusOK, w.code)
  211. }
  212. func TestWriteJsonMarshalFailed(t *testing.T) {
  213. w := tracedResponseWriter{
  214. headers: make(map[string][]string),
  215. }
  216. WriteJson(&w, http.StatusOK, map[string]any{
  217. "Data": complex(0, 0),
  218. })
  219. assert.Equal(t, http.StatusInternalServerError, w.code)
  220. }
  221. type tracedResponseWriter struct {
  222. headers map[string][]string
  223. builder strings.Builder
  224. hasBody bool
  225. code int
  226. lessWritten bool
  227. wroteHeader bool
  228. err error
  229. }
  230. func (w *tracedResponseWriter) Header() http.Header {
  231. return w.headers
  232. }
  233. func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
  234. if w.err != nil {
  235. return 0, w.err
  236. }
  237. n, err = w.builder.Write(bytes)
  238. if w.lessWritten {
  239. n--
  240. }
  241. w.hasBody = true
  242. return
  243. }
  244. func (w *tracedResponseWriter) WriteHeader(code int) {
  245. if w.wroteHeader {
  246. return
  247. }
  248. w.wroteHeader = true
  249. w.code = code
  250. }
  251. func TestErrorCtx(t *testing.T) {
  252. const (
  253. body = "foo"
  254. wrappedBody = `"foo"`
  255. )
  256. tests := []struct {
  257. name string
  258. input string
  259. errorHandlerCtx func(context.Context, error) (int, any)
  260. expectHasBody bool
  261. expectBody string
  262. expectCode int
  263. }{
  264. {
  265. name: "default error handler",
  266. input: body,
  267. expectHasBody: true,
  268. expectBody: body,
  269. expectCode: http.StatusBadRequest,
  270. },
  271. {
  272. name: "customized error handler return string",
  273. input: body,
  274. errorHandlerCtx: func(ctx context.Context, err error) (int, any) {
  275. return http.StatusForbidden, err.Error()
  276. },
  277. expectHasBody: true,
  278. expectBody: wrappedBody,
  279. expectCode: http.StatusForbidden,
  280. },
  281. {
  282. name: "customized error handler return error",
  283. input: body,
  284. errorHandlerCtx: func(ctx context.Context, err error) (int, any) {
  285. return http.StatusForbidden, err
  286. },
  287. expectHasBody: true,
  288. expectBody: body,
  289. expectCode: http.StatusForbidden,
  290. },
  291. {
  292. name: "customized error handler return nil",
  293. input: body,
  294. errorHandlerCtx: func(context.Context, error) (int, any) {
  295. return http.StatusForbidden, nil
  296. },
  297. expectHasBody: false,
  298. expectBody: "",
  299. expectCode: http.StatusForbidden,
  300. },
  301. }
  302. for _, test := range tests {
  303. t.Run(test.name, func(t *testing.T) {
  304. w := tracedResponseWriter{
  305. headers: make(map[string][]string),
  306. }
  307. if test.errorHandlerCtx != nil {
  308. errorLock.RLock()
  309. prev := errorHandler
  310. errorLock.RUnlock()
  311. SetErrorHandlerCtx(test.errorHandlerCtx)
  312. defer func() {
  313. errorLock.Lock()
  314. test.errorHandlerCtx = prev
  315. errorLock.Unlock()
  316. }()
  317. }
  318. ErrorCtx(context.Background(), &w, errors.New(test.input))
  319. assert.Equal(t, test.expectCode, w.code)
  320. assert.Equal(t, test.expectHasBody, w.hasBody)
  321. assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
  322. })
  323. }
  324. // The current handler is a global event,Set default values to avoid impacting subsequent unit tests
  325. SetErrorHandlerCtx(nil)
  326. }
  327. func TestErrorWithGrpcErrorCtx(t *testing.T) {
  328. w := tracedResponseWriter{
  329. headers: make(map[string][]string),
  330. }
  331. ErrorCtx(context.Background(), &w, status.Error(codes.Unavailable, "foo"))
  332. assert.Equal(t, http.StatusServiceUnavailable, w.code)
  333. assert.True(t, w.hasBody)
  334. assert.True(t, strings.Contains(w.builder.String(), "foo"))
  335. }
  336. func TestErrorWithHandlerCtx(t *testing.T) {
  337. w := tracedResponseWriter{
  338. headers: make(map[string][]string),
  339. }
  340. ErrorCtx(context.Background(), &w, errors.New("foo"), func(w http.ResponseWriter, err error) {
  341. http.Error(w, err.Error(), 499)
  342. })
  343. assert.Equal(t, 499, w.code)
  344. assert.True(t, w.hasBody)
  345. assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
  346. }
  347. func TestWriteJsonCtxMarshalFailed(t *testing.T) {
  348. w := tracedResponseWriter{
  349. headers: make(map[string][]string),
  350. }
  351. WriteJsonCtx(context.Background(), &w, http.StatusOK, map[string]any{
  352. "Data": complex(0, 0),
  353. })
  354. assert.Equal(t, http.StatusInternalServerError, w.code)
  355. }