chain.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package chain
  2. // This is a modified version of https://github.com/justinas/alice
  3. // The original code is licensed under the MIT license.
  4. // It's modified for a couple of reasons:
  5. // - Added the Chain interface
  6. // - Added support for the Chain.Prepend(...) method
  7. import "net/http"
  // Chain describes an ordered list of middleware that can be extended on
  // either end and finally wrapped around an http.Handler.
  type Chain interface {
  	// Append yields a Chain whose extra middlewares run after the existing ones.
  	Append(middlewares ...Middleware) Chain
  	// Prepend yields a Chain whose extra middlewares run before the existing ones.
  	Prepend(middlewares ...Middleware) Chain
  	// Then wraps h with every middleware of the chain.
  	Then(h http.Handler) http.Handler
  	// ThenFunc is Then for an http.HandlerFunc.
  	ThenFunc(fn http.HandlerFunc) http.Handler
  }

  // Middleware takes the next http.Handler and returns the handler that
  // should be used in its place.
  type Middleware func(http.Handler) http.Handler

  // chain is the package's Chain implementation: a plain list of middlewares.
  // It is effectively immutable — once built it always holds the same
  // middlewares in the same order, because every constructor/extender stores
  // a freshly allocated slice rather than sharing one.
  type chain struct {
  	middlewares []Middleware
  }
  26. // New creates a new Chain, memorizing the given list of middleware middlewares.
  27. // New serves no other function, middlewares are only called upon a call to Then() or ThenFunc().
  28. func New(middlewares ...Middleware) Chain {
  29. return chain{middlewares: append(([]Middleware)(nil), middlewares...)}
  30. }
  31. // Append extends a chain, adding the specified middlewares as the last ones in the request flow.
  32. //
  33. // c := chain.New(m1, m2)
  34. // c.Append(m3, m4)
  35. // // requests in c go m1 -> m2 -> m3 -> m4
  36. func (c chain) Append(middlewares ...Middleware) Chain {
  37. return chain{middlewares: join(c.middlewares, middlewares)}
  38. }
  39. // Prepend extends a chain by adding the specified chain as the first one in the request flow.
  40. //
  41. // c := chain.New(m3, m4)
  42. // c1 := chain.New(m1, m2)
  43. // c.Prepend(c1)
  44. // // requests in c go m1 -> m2 -> m3 -> m4
  45. func (c chain) Prepend(middlewares ...Middleware) Chain {
  46. return chain{middlewares: join(middlewares, c.middlewares)}
  47. }
  48. // Then chains the middleware and returns the final http.Handler.
  49. // New(m1, m2, m3).Then(h)
  50. // is equivalent to:
  51. // m1(m2(m3(h)))
  52. // When the request comes in, it will be passed to m1, then m2, then m3
  53. // and finally, the given handler
  54. // (assuming every middleware calls the following one).
  55. //
  56. // A chain can be safely reused by calling Then() several times.
  57. // stdStack := chain.New(ratelimitHandler, csrfHandler)
  58. // indexPipe = stdStack.Then(indexHandler)
  59. // authPipe = stdStack.Then(authHandler)
  60. // Note that middlewares are called on every call to Then() or ThenFunc()
  61. // and thus several instances of the same middleware will be created
  62. // when a chain is reused in this way.
  63. // For proper middleware, this should cause no problems.
  64. //
  65. // Then() treats nil as http.DefaultServeMux.
  66. func (c chain) Then(h http.Handler) http.Handler {
  67. if h == nil {
  68. h = http.DefaultServeMux
  69. }
  70. for i := range c.middlewares {
  71. h = c.middlewares[len(c.middlewares)-1-i](h)
  72. }
  73. return h
  74. }
  75. // ThenFunc works identically to Then, but takes
  76. // a HandlerFunc instead of a Handler.
  77. //
  78. // The following two statements are equivalent:
  79. // c.Then(http.HandlerFunc(fn))
  80. // c.ThenFunc(fn)
  81. //
  82. // ThenFunc provides all the guarantees of Then.
  83. func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {
  84. // This nil check cannot be removed due to the "nil is not nil" common mistake in Go.
  85. // Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil
  86. if fn == nil {
  87. return c.Then(nil)
  88. }
  89. return c.Then(fn)
  90. }
  91. func join(a, b []Middleware) []Middleware {
  92. mids := make([]Middleware, 0, len(a)+len(b))
  93. mids = append(mids, a...)
  94. mids = append(mids, b...)
  95. return mids
  96. }