chain.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package chain
  2. // This is a modified version of https://github.com/justinas/alice
  3. // The original code is licensed under the MIT license.
// It has been modified for a couple of reasons:
  5. // - Added the Chain interface
  6. // - Added support for the Chain.Prepend(...) method
  7. import "net/http"
  8. type (
  9. // Chain defines a chain of middleware.
  10. Chain interface {
  11. Append(middlewares ...Middleware) Chain
  12. Prepend(middlewares ...Middleware) Chain
  13. Then(h http.Handler) http.Handler
  14. ThenFunc(fn http.HandlerFunc) http.Handler
  15. }
  16. // Middleware is an HTTP middleware.
  17. Middleware func(http.Handler) http.Handler
  18. // chain acts as a list of http.Handler middlewares.
  19. // chain is effectively immutable:
  20. // once created, it will always hold
  21. // the same set of middlewares in the same order.
  22. chain struct {
  23. middlewares []Middleware
  24. }
  25. )
  26. // New creates a new Chain, memorizing the given list of middleware middlewares.
  27. // New serves no other function, middlewares are only called upon a call to Then() or ThenFunc().
  28. func New(middlewares ...Middleware) Chain {
  29. return chain{middlewares: append(([]Middleware)(nil), middlewares...)}
  30. }
  31. // Append extends a chain, adding the specified middlewares as the last ones in the request flow.
  32. //
  33. // c := chain.New(m1, m2)
  34. // c.Append(m3, m4)
  35. // // requests in c go m1 -> m2 -> m3 -> m4
  36. func (c chain) Append(middlewares ...Middleware) Chain {
  37. return chain{middlewares: join(c.middlewares, middlewares)}
  38. }
  39. // Prepend extends a chain by adding the specified chain as the first one in the request flow.
  40. //
  41. // c := chain.New(m3, m4)
  42. // c1 := chain.New(m1, m2)
  43. // c.Prepend(c1)
  44. // // requests in c go m1 -> m2 -> m3 -> m4
  45. func (c chain) Prepend(middlewares ...Middleware) Chain {
  46. return chain{middlewares: join(middlewares, c.middlewares)}
  47. }
  48. // Then chains the middleware and returns the final http.Handler.
  49. //
  50. // New(m1, m2, m3).Then(h)
  51. //
  52. // is equivalent to:
  53. //
  54. // m1(m2(m3(h)))
  55. //
  56. // When the request comes in, it will be passed to m1, then m2, then m3
  57. // and finally, the given handler
  58. // (assuming every middleware calls the following one).
  59. //
  60. // A chain can be safely reused by calling Then() several times.
  61. //
  62. // stdStack := chain.New(ratelimitHandler, csrfHandler)
  63. // indexPipe = stdStack.Then(indexHandler)
  64. // authPipe = stdStack.Then(authHandler)
  65. //
  66. // Note that middlewares are called on every call to Then() or ThenFunc()
  67. // and thus several instances of the same middleware will be created
  68. // when a chain is reused in this way.
  69. // For proper middleware, this should cause no problems.
  70. //
  71. // Then() treats nil as http.DefaultServeMux.
  72. func (c chain) Then(h http.Handler) http.Handler {
  73. if h == nil {
  74. h = http.DefaultServeMux
  75. }
  76. for i := range c.middlewares {
  77. h = c.middlewares[len(c.middlewares)-1-i](h)
  78. }
  79. return h
  80. }
  81. // ThenFunc works identically to Then, but takes
  82. // a HandlerFunc instead of a Handler.
  83. //
  84. // The following two statements are equivalent:
  85. //
  86. // c.Then(http.HandlerFunc(fn))
  87. // c.ThenFunc(fn)
  88. //
  89. // ThenFunc provides all the guarantees of Then.
  90. func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {
  91. // This nil check cannot be removed due to the "nil is not nil" common mistake in Go.
  92. // Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil
  93. if fn == nil {
  94. return c.Then(nil)
  95. }
  96. return c.Then(fn)
  97. }
  98. func join(a, b []Middleware) []Middleware {
  99. mids := make([]Middleware, 0, len(a)+len(b))
  100. mids = append(mids, a...)
  101. mids = append(mids, b...)
  102. return mids
  103. }