responses.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package httpx
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"sync"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/rest/internal/errcode"
	"github.com/zeromicro/go-zero/rest/internal/header"
)
  12. var (
  13. errorHandler func(context.Context, error) (int, any)
  14. errorLock sync.RWMutex
  15. okHandler func(context.Context, any) any
  16. okLock sync.RWMutex
  17. )
  18. // Error writes err into w.
  19. func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
  20. doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
  21. }
  22. // ErrorCtx writes err into w.
  23. func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
  24. fns ...func(w http.ResponseWriter, err error)) {
  25. writeJson := func(w http.ResponseWriter, code int, v any) {
  26. WriteJsonCtx(ctx, w, code, v)
  27. }
  28. doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
  29. }
  30. // Ok writes HTTP 200 OK into w.
  31. func Ok(w http.ResponseWriter) {
  32. w.WriteHeader(http.StatusOK)
  33. }
  34. // OkJson writes v into w with 200 OK.
  35. func OkJson(w http.ResponseWriter, v any) {
  36. okLock.RLock()
  37. handler := okHandler
  38. okLock.RUnlock()
  39. if handler != nil {
  40. v = handler(context.Background(), v)
  41. }
  42. WriteJson(w, http.StatusOK, v)
  43. }
  44. // OkJsonCtx writes v into w with 200 OK.
  45. func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
  46. okLock.RLock()
  47. handlerCtx := okHandler
  48. okLock.RUnlock()
  49. if handlerCtx != nil {
  50. v = handlerCtx(ctx, v)
  51. }
  52. WriteJsonCtx(ctx, w, http.StatusOK, v)
  53. }
  54. // SetErrorHandler sets the error handler, which is called on calling Error.
  55. // Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
  56. // Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
  57. func SetErrorHandler(handler func(error) (int, any)) {
  58. errorLock.Lock()
  59. defer errorLock.Unlock()
  60. errorHandler = func(_ context.Context, err error) (int, any) {
  61. return handler(err)
  62. }
  63. }
  64. // SetErrorHandlerCtx sets the error handler, which is called on calling Error.
  65. // Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
  66. // Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
  67. func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
  68. errorLock.Lock()
  69. defer errorLock.Unlock()
  70. errorHandler = handlerCtx
  71. }
  72. // SetOkHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
  73. func SetOkHandler(handler func(context.Context, any) any) {
  74. okLock.Lock()
  75. defer okLock.Unlock()
  76. okHandler = handler
  77. }
  78. // WriteJson writes v as json string into w with code.
  79. func WriteJson(w http.ResponseWriter, code int, v any) {
  80. if err := doWriteJson(w, code, v); err != nil {
  81. logx.Error(err)
  82. }
  83. }
  84. // WriteJsonCtx writes v as json string into w with code.
  85. func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
  86. if err := doWriteJson(w, code, v); err != nil {
  87. logx.WithContext(ctx).Error(err)
  88. }
  89. }
  90. func buildErrorHandler(ctx context.Context) func(error) (int, any) {
  91. errorLock.RLock()
  92. handlerCtx := errorHandler
  93. errorLock.RUnlock()
  94. var handler func(error) (int, any)
  95. if handlerCtx != nil {
  96. handler = func(err error) (int, any) {
  97. return handlerCtx(ctx, err)
  98. }
  99. }
  100. return handler
  101. }
  102. func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
  103. writeJson func(w http.ResponseWriter, code int, v any),
  104. fns ...func(w http.ResponseWriter, err error)) {
  105. if handler == nil {
  106. if len(fns) > 0 {
  107. for _, fn := range fns {
  108. fn(w, err)
  109. }
  110. } else if errcode.IsGrpcError(err) {
  111. // don't unwrap error and get status.Message(),
  112. // it hides the rpc error headers.
  113. http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
  114. } else {
  115. http.Error(w, err.Error(), http.StatusBadRequest)
  116. }
  117. return
  118. }
  119. code, body := handler(err)
  120. if body == nil {
  121. w.WriteHeader(code)
  122. return
  123. }
  124. e, ok := body.(error)
  125. if ok {
  126. http.Error(w, e.Error(), code)
  127. } else {
  128. writeJson(w, code, body)
  129. }
  130. }
  131. func doWriteJson(w http.ResponseWriter, code int, v any) error {
  132. bs, err := json.Marshal(v)
  133. if err != nil {
  134. http.Error(w, err.Error(), http.StatusInternalServerError)
  135. return fmt.Errorf("marshal json failed, error: %w", err)
  136. }
  137. w.Header().Set(ContentType, header.JsonContentType)
  138. w.WriteHeader(code)
  139. if n, err := w.Write(bs); err != nil {
  140. // http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
  141. // so it's ignored here.
  142. if err != http.ErrHandlerTimeout {
  143. return fmt.Errorf("write response failed, error: %w", err)
  144. }
  145. } else if n < len(bs) {
  146. return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
  147. }
  148. return nil
  149. }