breakerhandler_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package handler
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/http/httptest"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/wuntsong-org/go-zero-plus/core/stat"
  9. )
  10. func init() {
  11. stat.SetReporter(nil)
  12. }
  13. func TestBreakerHandlerAccept(t *testing.T) {
  14. metrics := stat.NewMetrics("unit-test")
  15. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  16. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  17. w.Header().Set("X-Test", "test")
  18. _, err := w.Write([]byte("content"))
  19. assert.Nil(t, err)
  20. }))
  21. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  22. req.Header.Set("X-Test", "test")
  23. resp := httptest.NewRecorder()
  24. handler.ServeHTTP(resp, req)
  25. assert.Equal(t, http.StatusOK, resp.Code)
  26. assert.Equal(t, "test", resp.Header().Get("X-Test"))
  27. assert.Equal(t, "content", resp.Body.String())
  28. }
  29. func TestBreakerHandlerFail(t *testing.T) {
  30. metrics := stat.NewMetrics("unit-test")
  31. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  32. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  33. w.WriteHeader(http.StatusBadGateway)
  34. }))
  35. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  36. resp := httptest.NewRecorder()
  37. handler.ServeHTTP(resp, req)
  38. assert.Equal(t, http.StatusBadGateway, resp.Code)
  39. }
  40. func TestBreakerHandler_4XX(t *testing.T) {
  41. metrics := stat.NewMetrics("unit-test")
  42. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  43. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  44. w.WriteHeader(http.StatusBadRequest)
  45. }))
  46. for i := 0; i < 1000; i++ {
  47. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  48. resp := httptest.NewRecorder()
  49. handler.ServeHTTP(resp, req)
  50. }
  51. const tries = 100
  52. var pass int
  53. for i := 0; i < tries; i++ {
  54. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  55. resp := httptest.NewRecorder()
  56. handler.ServeHTTP(resp, req)
  57. if resp.Code == http.StatusBadRequest {
  58. pass++
  59. }
  60. }
  61. assert.Equal(t, tries, pass)
  62. }
  63. func TestBreakerHandlerReject(t *testing.T) {
  64. metrics := stat.NewMetrics("unit-test")
  65. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  66. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  67. w.WriteHeader(http.StatusInternalServerError)
  68. }))
  69. for i := 0; i < 1000; i++ {
  70. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  71. resp := httptest.NewRecorder()
  72. handler.ServeHTTP(resp, req)
  73. }
  74. var drops int
  75. for i := 0; i < 100; i++ {
  76. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  77. resp := httptest.NewRecorder()
  78. handler.ServeHTTP(resp, req)
  79. if resp.Code == http.StatusServiceUnavailable {
  80. drops++
  81. }
  82. }
  83. assert.True(t, drops >= 80, fmt.Sprintf("expected to be greater than 80, but got %d", drops))
  84. }