maxconnshandler_test.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package handler
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "sync"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/wuntsong-org/go-zero-plus/core/lang"
  9. )
  10. const conns = 4 // number of concurrent blocking requests each test launches; also the limit passed to MaxConnsHandler in TestMaxConnsHandler
  11. func TestMaxConnsHandler(t *testing.T) {
  12. var waitGroup sync.WaitGroup
  13. waitGroup.Add(conns)
  14. done := make(chan lang.PlaceholderType)
  15. defer close(done)
  16. maxConns := MaxConnsHandler(conns)
  17. handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  18. waitGroup.Done()
  19. <-done
  20. }))
  21. for i := 0; i < conns; i++ {
  22. go func() {
  23. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  24. handler.ServeHTTP(httptest.NewRecorder(), req)
  25. }()
  26. }
  27. waitGroup.Wait()
  28. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  29. resp := httptest.NewRecorder()
  30. handler.ServeHTTP(resp, req)
  31. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  32. }
  33. func TestWithoutMaxConnsHandler(t *testing.T) {
  34. const (
  35. key = "block"
  36. value = "1"
  37. )
  38. var waitGroup sync.WaitGroup
  39. waitGroup.Add(conns)
  40. done := make(chan lang.PlaceholderType)
  41. defer close(done)
  42. maxConns := MaxConnsHandler(0)
  43. handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  44. val := r.Header.Get(key)
  45. if val == value {
  46. waitGroup.Done()
  47. <-done
  48. }
  49. }))
  50. for i := 0; i < conns; i++ {
  51. go func() {
  52. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  53. req.Header.Set(key, value)
  54. handler.ServeHTTP(httptest.NewRecorder(), req)
  55. }()
  56. }
  57. waitGroup.Wait()
  58. req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
  59. resp := httptest.NewRecorder()
  60. handler.ServeHTTP(resp, req)
  61. assert.Equal(t, http.StatusOK, resp.Code)
  62. }