authhandler.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package handler
import (
	"context"
	"errors"
	"net/http"
	"net/http/httputil"

	"github.com/dgrijalva/jwt-go"

	"zero/core/logx"
	"zero/ngin/internal"
)
  10. const (
  11. jwtAudience = "aud"
  12. jwtExpire = "exp"
  13. jwtId = "jti"
  14. jwtIssueAt = "iat"
  15. jwtIssuer = "iss"
  16. jwtNotBefore = "nbf"
  17. jwtSubject = "sub"
  18. )
  19. type (
  20. AuthorizeOptions struct {
  21. PrevSecret string
  22. Callback UnauthorizedCallback
  23. }
  24. UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error)
  25. AuthorizeOption func(opts *AuthorizeOptions)
  26. )
  27. func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler {
  28. var authOpts AuthorizeOptions
  29. for _, opt := range opts {
  30. opt(&authOpts)
  31. }
  32. parser := internal.NewTokenParser()
  33. return func(next http.Handler) http.Handler {
  34. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  35. token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
  36. if err != nil {
  37. unauthorized(w, r, err, authOpts.Callback)
  38. return
  39. }
  40. if !token.Valid {
  41. unauthorized(w, r, err, authOpts.Callback)
  42. return
  43. }
  44. claims, ok := token.Claims.(jwt.MapClaims)
  45. if !ok {
  46. unauthorized(w, r, err, authOpts.Callback)
  47. return
  48. }
  49. ctx := r.Context()
  50. for k, v := range claims {
  51. switch k {
  52. case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject:
  53. // ignore the standard claims
  54. default:
  55. ctx = context.WithValue(ctx, k, v)
  56. }
  57. }
  58. next.ServeHTTP(w, r.WithContext(ctx))
  59. })
  60. }
  61. }
  62. func WithPrevSecret(secret string) AuthorizeOption {
  63. return func(opts *AuthorizeOptions) {
  64. opts.PrevSecret = secret
  65. }
  66. }
  67. func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
  68. return func(opts *AuthorizeOptions) {
  69. opts.Callback = callback
  70. }
  71. }
  72. func detailAuthLog(r *http.Request, reason string) {
  73. // discard dump error, only for debug purpose
  74. details, _ := httputil.DumpRequest(r, true)
  75. logx.Errorf("authorize failed: %s\n=> %+v", reason, string(details))
  76. }
  77. func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
  78. writer := newGuardedResponseWriter(w)
  79. detailAuthLog(r, err.Error())
  80. if callback != nil {
  81. callback(writer, r, err)
  82. }
  83. writer.WriteHeader(http.StatusUnauthorized)
  84. }
  85. type guardedResponseWriter struct {
  86. writer http.ResponseWriter
  87. wroteHeader bool
  88. }
  89. func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
  90. return &guardedResponseWriter{
  91. writer: w,
  92. }
  93. }
  94. func (grw *guardedResponseWriter) Header() http.Header {
  95. return grw.writer.Header()
  96. }
  97. func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
  98. return grw.writer.Write(body)
  99. }
  100. func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
  101. if grw.wroteHeader {
  102. return
  103. }
  104. grw.wroteHeader = true
  105. grw.writer.WriteHeader(statusCode)
  106. }