maxconnshandler_test.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. package handler
  2. import (
  3. "io/ioutil"
  4. "log"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync"
  8. "testing"
  9. "zero/core/lang"
  10. "github.com/stretchr/testify/assert"
  11. )
  // conns is the number of requests each test parks inside the wrapped
  // handler at the same time; TestMaxConnsHandler also uses it as the
  // limit passed to MaxConns.
  const (
  	conns = 4
  )
  // init redirects the standard logger to a discarding writer for the whole
  // test binary, so anything logged while these tests run is dropped.
  func init() {
  	sink := ioutil.Discard
  	log.SetOutput(sink)
  }
  16. func TestMaxConnsHandler(t *testing.T) {
  17. var waitGroup sync.WaitGroup
  18. waitGroup.Add(conns)
  19. done := make(chan lang.PlaceholderType)
  20. defer close(done)
  21. maxConns := MaxConns(conns)
  22. handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  23. waitGroup.Done()
  24. <-done
  25. }))
  26. for i := 0; i < conns; i++ {
  27. go func() {
  28. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  29. handler.ServeHTTP(httptest.NewRecorder(), req)
  30. }()
  31. }
  32. waitGroup.Wait()
  33. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  34. resp := httptest.NewRecorder()
  35. handler.ServeHTTP(resp, req)
  36. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  37. }
  38. func TestWithoutMaxConnsHandler(t *testing.T) {
  39. const (
  40. key = "block"
  41. value = "1"
  42. )
  43. var waitGroup sync.WaitGroup
  44. waitGroup.Add(conns)
  45. done := make(chan lang.PlaceholderType)
  46. defer close(done)
  47. maxConns := MaxConns(0)
  48. handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  49. val := r.Header.Get(key)
  50. if val == value {
  51. waitGroup.Done()
  52. <-done
  53. }
  54. }))
  55. for i := 0; i < conns; i++ {
  56. go func() {
  57. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  58. req.Header.Set(key, value)
  59. handler.ServeHTTP(httptest.NewRecorder(), req)
  60. }()
  61. }
  62. waitGroup.Wait()
  63. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  64. resp := httptest.NewRecorder()
  65. handler.ServeHTTP(resp, req)
  66. assert.Equal(t, http.StatusOK, resp.Code)
  67. }