contentsecurityhandler.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. package handler
  2. import (
  3. "net/http"
  4. "time"
  5. "github.com/wuntsong-org/go-zero-plus/core/codec"
  6. "github.com/wuntsong-org/go-zero-plus/core/logx"
  7. "github.com/wuntsong-org/go-zero-plus/rest/httpx"
  8. "github.com/wuntsong-org/go-zero-plus/rest/internal/security"
  9. )
  // contentSecurity is the name of the HTTP request header whose raw value is
// included in the error log whenever content-security parsing or signature
// verification fails.
const (
	contentSecurity = "X-Content-Security"
)
  // UnsignedCallback defines the method of the unsigned callback.
//
// The content-security middleware invokes it when a request's security header
// cannot be parsed or its signature does not verify:
//   - w, r: the response writer and request being handled;
//   - next: the handler wrapped by the middleware;
//   - strict: the strict flag the middleware was constructed with;
//   - code: the failure result code reported for the request.
type UnsignedCallback func(
	w http.ResponseWriter,
	r *http.Request,
	next http.Handler,
	strict bool,
	code int,
)
  13. // ContentSecurityHandler returns a middleware to verify content security.
  14. func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration,
  15. strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
  16. return LimitContentSecurityHandler(maxBytes, decrypters, tolerance, strict, callbacks...)
  17. }
  18. // LimitContentSecurityHandler returns a middleware to verify content security.
  19. func LimitContentSecurityHandler(limitBytes int64, decrypters map[string]codec.RsaDecrypter,
  20. tolerance time.Duration, strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
  21. if len(callbacks) == 0 {
  22. callbacks = append(callbacks, handleVerificationFailure)
  23. }
  24. return func(next http.Handler) http.Handler {
  25. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  26. switch r.Method {
  27. case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
  28. header, err := security.ParseContentSecurity(decrypters, r)
  29. if err != nil {
  30. logx.Errorf("Signature parse failed, X-Content-Security: %s, error: %s",
  31. r.Header.Get(contentSecurity), err.Error())
  32. executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
  33. } else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
  34. logx.Errorf("Signature verification failed, X-Content-Security: %s",
  35. r.Header.Get(contentSecurity))
  36. executeCallbacks(w, r, next, strict, code, callbacks)
  37. } else if r.ContentLength > 0 && header.Encrypted() {
  38. LimitCryptionHandler(limitBytes, header.Key)(next).ServeHTTP(w, r)
  39. } else {
  40. next.ServeHTTP(w, r)
  41. }
  42. default:
  43. next.ServeHTTP(w, r)
  44. }
  45. })
  46. }
  47. }
  48. func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool,
  49. code int, callbacks []UnsignedCallback) {
  50. for _, callback := range callbacks {
  51. callback(w, r, next, strict, code)
  52. }
  53. }
  // handleVerificationFailure is the default UnsignedCallback used when no
// callbacks are supplied to the content-security middleware.
// In strict mode it rejects the request with 403 Forbidden and does not call
// next. Otherwise it lets the request continue to next. The result code is
// ignored.
func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
	if !strict {
		next.ServeHTTP(w, r)
		return
	}

	w.WriteHeader(http.StatusForbidden)
}