1
0

patrouter.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package router
import (
	"errors"
	"net/http"
	"path"
	"sort"
	"strings"

	"github.com/wuntsong-org/go-zero-plus/core/search"
	"github.com/wuntsong-org/go-zero-plus/rest/httpx"
	"github.com/wuntsong-org/go-zero-plus/rest/pathvar"
)
  11. const (
  12. allowHeader = "Allow"
  13. allowMethodSeparator = ", "
  14. )
  15. var (
  16. // ErrInvalidMethod is an error that indicates not a valid http method.
  17. ErrInvalidMethod = errors.New("not a valid http method")
  18. // ErrInvalidPath is an error that indicates path is not start with /.
  19. ErrInvalidPath = errors.New("path must begin with '/'")
  20. )
  21. type patRouter struct {
  22. trees map[string]*search.Tree
  23. notFound http.Handler
  24. notAllowed http.Handler
  25. optionsHandle http.Handler
  26. middleware httpx.MiddlewareFunc
  27. }
  28. // NewRouter returns a httpx.Router.
  29. func NewRouter() httpx.Router {
  30. return &patRouter{
  31. trees: make(map[string]*search.Tree),
  32. }
  33. }
  34. func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error {
  35. if !validMethod(method) {
  36. return ErrInvalidMethod
  37. }
  38. if len(reqPath) == 0 || reqPath[0] != '/' {
  39. return ErrInvalidPath
  40. }
  41. cleanPath := path.Clean(reqPath)
  42. tree, ok := pr.trees[method]
  43. if ok {
  44. return tree.Add(cleanPath, handler)
  45. }
  46. tree = search.NewTree()
  47. pr.trees[method] = tree
  48. return tree.Add(cleanPath, handler)
  49. }
  50. func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  51. if pr.middleware != nil {
  52. pr.middleware(http.HandlerFunc(pr.serveHTTP)).ServeHTTP(w, r)
  53. } else {
  54. pr.serveHTTP(w, r)
  55. }
  56. }
  57. func (pr *patRouter) serveHTTP(w http.ResponseWriter, r *http.Request) {
  58. if r.Method == http.MethodOptions && pr.optionsHandle != nil {
  59. pr.optionsHandle.ServeHTTP(w, r)
  60. return
  61. }
  62. reqPath := path.Clean(r.URL.Path)
  63. if tree, ok := pr.trees[r.Method]; ok {
  64. if result, ok := tree.Search(reqPath); ok {
  65. if len(result.Params) > 0 {
  66. r = pathvar.WithVars(r, result.Params)
  67. }
  68. result.Item.(http.Handler).ServeHTTP(w, r)
  69. return
  70. }
  71. }
  72. allows, ok := pr.methodsAllowed(r.Method, reqPath)
  73. if !ok {
  74. pr.handleNotFound(w, r)
  75. return
  76. }
  77. if pr.notAllowed != nil {
  78. pr.notAllowed.ServeHTTP(w, r)
  79. } else {
  80. w.Header().Set(allowHeader, allows)
  81. w.WriteHeader(http.StatusMethodNotAllowed)
  82. }
  83. }
  84. func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
  85. pr.notFound = handler
  86. }
  87. func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
  88. pr.notAllowed = handler
  89. }
  90. func (pr *patRouter) SetOptionsHandler(handler http.Handler) {
  91. f := func(w http.ResponseWriter, r *http.Request) {
  92. if r.Method != http.MethodOptions {
  93. pr.notAllowed.ServeHTTP(w, r)
  94. return
  95. }
  96. handler.ServeHTTP(w, r)
  97. }
  98. pr.optionsHandle = http.HandlerFunc(f)
  99. }
  100. func (pr *patRouter) SetMiddleware(middleware httpx.MiddlewareFunc) {
  101. pr.middleware = middleware
  102. }
  103. func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
  104. if pr.notFound != nil {
  105. pr.notFound.ServeHTTP(w, r)
  106. } else {
  107. http.NotFound(w, r)
  108. }
  109. }
  110. func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
  111. var allows []string
  112. for treeMethod, tree := range pr.trees {
  113. if treeMethod == method {
  114. continue
  115. }
  116. _, ok := tree.Search(path)
  117. if ok {
  118. allows = append(allows, treeMethod)
  119. }
  120. }
  121. if len(allows) > 0 {
  122. return strings.Join(allows, allowMethodSeparator), true
  123. }
  124. return "", false
  125. }
  126. func validMethod(method string) bool {
  127. return method == http.MethodDelete || method == http.MethodGet ||
  128. method == http.MethodHead || method == http.MethodOptions ||
  129. method == http.MethodPatch || method == http.MethodPost ||
  130. method == http.MethodPut
  131. }