responses.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package httpx
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"sync"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/rest/internal/errcode"
	"github.com/zeromicro/go-zero/rest/internal/header"
)
  12. var (
  13. errorHandler func(error) (int, any)
  14. errorHandlerCtx func(context.Context, error) (int, any)
  15. lock sync.RWMutex
  16. )
  17. // Error writes err into w.
  18. func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
  19. lock.RLock()
  20. handler := errorHandler
  21. lock.RUnlock()
  22. doHandleError(w, err, handler, WriteJson, fns...)
  23. }
  24. // ErrorCtx writes err into w.
  25. func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
  26. fns ...func(w http.ResponseWriter, err error)) {
  27. lock.RLock()
  28. handlerCtx := errorHandlerCtx
  29. lock.RUnlock()
  30. var handler func(error) (int, any)
  31. if handlerCtx != nil {
  32. handler = func(err error) (int, any) {
  33. return handlerCtx(ctx, err)
  34. }
  35. }
  36. writeJson := func(w http.ResponseWriter, code int, v any) {
  37. WriteJsonCtx(ctx, w, code, v)
  38. }
  39. doHandleError(w, err, handler, writeJson, fns...)
  40. }
  41. // Ok writes HTTP 200 OK into w.
  42. func Ok(w http.ResponseWriter) {
  43. w.WriteHeader(http.StatusOK)
  44. }
  45. // OkJson writes v into w with 200 OK.
  46. func OkJson(w http.ResponseWriter, v any) {
  47. WriteJson(w, http.StatusOK, v)
  48. }
  49. // OkJsonCtx writes v into w with 200 OK.
  50. func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
  51. WriteJsonCtx(ctx, w, http.StatusOK, v)
  52. }
  53. // SetErrorHandler sets the error handler, which is called on calling Error.
  54. func SetErrorHandler(handler func(error) (int, any)) {
  55. lock.Lock()
  56. defer lock.Unlock()
  57. errorHandler = handler
  58. }
  59. // SetErrorHandlerCtx sets the error handler, which is called on calling Error.
  60. func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
  61. lock.Lock()
  62. defer lock.Unlock()
  63. errorHandlerCtx = handlerCtx
  64. }
  65. // WriteJson writes v as json string into w with code.
  66. func WriteJson(w http.ResponseWriter, code int, v any) {
  67. if err := doWriteJson(w, code, v); err != nil {
  68. logx.Error(err)
  69. }
  70. }
  71. // WriteJsonCtx writes v as json string into w with code.
  72. func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
  73. if err := doWriteJson(w, code, v); err != nil {
  74. logx.WithContext(ctx).Error(err)
  75. }
  76. }
  77. func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
  78. writeJson func(w http.ResponseWriter, code int, v any),
  79. fns ...func(w http.ResponseWriter, err error)) {
  80. if handler == nil {
  81. if len(fns) > 0 {
  82. for _, fn := range fns {
  83. fn(w, err)
  84. }
  85. } else if errcode.IsGrpcError(err) {
  86. // don't unwrap error and get status.Message(),
  87. // it hides the rpc error headers.
  88. http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
  89. } else {
  90. http.Error(w, err.Error(), http.StatusBadRequest)
  91. }
  92. return
  93. }
  94. code, body := handler(err)
  95. if body == nil {
  96. w.WriteHeader(code)
  97. return
  98. }
  99. e, ok := body.(error)
  100. if ok {
  101. http.Error(w, e.Error(), code)
  102. } else {
  103. writeJson(w, code, body)
  104. }
  105. }
  106. func doWriteJson(w http.ResponseWriter, code int, v any) error {
  107. bs, err := json.Marshal(v)
  108. if err != nil {
  109. http.Error(w, err.Error(), http.StatusInternalServerError)
  110. return fmt.Errorf("marshal json failed, error: %w", err)
  111. }
  112. w.Header().Set(ContentType, header.JsonContentType)
  113. w.WriteHeader(code)
  114. if n, err := w.Write(bs); err != nil {
  115. // http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
  116. // so it's ignored here.
  117. if err != http.ErrHandlerTimeout {
  118. return fmt.Errorf("write response failed, error: %w", err)
  119. }
  120. } else if n < len(bs) {
  121. return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
  122. }
  123. return nil
  124. }