responses.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package httpx
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "sync"
  9. "github.com/wuntsong-org/go-zero-plus/core/logx"
  10. "github.com/wuntsong-org/go-zero-plus/rest/internal/errcode"
  11. "github.com/wuntsong-org/go-zero-plus/rest/internal/header"
  12. )
var (
	// errorHandler is the package-wide custom error handler. It is assigned by
	// SetErrorHandler and SetErrorHandlerCtx and read by buildErrorHandler;
	// nil means no custom handler has been registered.
	errorHandler func(context.Context, error) (int, any)
	// errorLock guards every read and write of errorHandler.
	errorLock sync.RWMutex
	// okHandler, when non-nil, is applied by OkJson and OkJsonCtx to the value
	// before it is written. It is assigned by SetOkHandler.
	okHandler func(context.Context, any) any
	// okLock guards every read and write of okHandler.
	okLock sync.RWMutex
)
  19. // Error writes err into w.
  20. func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
  21. doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
  22. }
  23. // ErrorCtx writes err into w.
  24. func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
  25. fns ...func(w http.ResponseWriter, err error)) {
  26. writeJson := func(w http.ResponseWriter, code int, v any) {
  27. WriteJsonCtx(ctx, w, code, v)
  28. }
  29. doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
  30. }
  31. // Ok writes HTTP 200 OK into w.
  32. func Ok(w http.ResponseWriter) {
  33. w.WriteHeader(http.StatusOK)
  34. }
  35. // OkJson writes v into w with 200 OK.
  36. func OkJson(w http.ResponseWriter, v any) {
  37. okLock.RLock()
  38. handler := okHandler
  39. okLock.RUnlock()
  40. if handler != nil {
  41. v = handler(context.Background(), v)
  42. }
  43. WriteJson(w, http.StatusOK, v)
  44. }
  45. // OkJsonCtx writes v into w with 200 OK.
  46. func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
  47. okLock.RLock()
  48. handlerCtx := okHandler
  49. okLock.RUnlock()
  50. if handlerCtx != nil {
  51. v = handlerCtx(ctx, v)
  52. }
  53. WriteJsonCtx(ctx, w, http.StatusOK, v)
  54. }
  55. // SetErrorHandler sets the error handler, which is called on calling Error.
  56. // Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
  57. // Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
  58. func SetErrorHandler(handler func(error) (int, any)) {
  59. errorLock.Lock()
  60. defer errorLock.Unlock()
  61. errorHandler = func(_ context.Context, err error) (int, any) {
  62. return handler(err)
  63. }
  64. }
  65. // SetErrorHandlerCtx sets the error handler, which is called on calling Error.
  66. // Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
  67. // Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
  68. func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
  69. errorLock.Lock()
  70. defer errorLock.Unlock()
  71. errorHandler = handlerCtx
  72. }
  73. // SetOkHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
  74. func SetOkHandler(handler func(context.Context, any) any) {
  75. okLock.Lock()
  76. defer okLock.Unlock()
  77. okHandler = handler
  78. }
  79. // WriteJson writes v as json string into w with code.
  80. func WriteJson(w http.ResponseWriter, code int, v any) {
  81. if err := doWriteJson(w, code, v); err != nil {
  82. logx.Error(err)
  83. }
  84. }
  85. // WriteJsonCtx writes v as json string into w with code.
  86. func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
  87. if err := doWriteJson(w, code, v); err != nil {
  88. logx.WithContext(ctx).Error(err)
  89. }
  90. }
  91. func buildErrorHandler(ctx context.Context) func(error) (int, any) {
  92. errorLock.RLock()
  93. handlerCtx := errorHandler
  94. errorLock.RUnlock()
  95. var handler func(error) (int, any)
  96. if handlerCtx != nil {
  97. handler = func(err error) (int, any) {
  98. return handlerCtx(ctx, err)
  99. }
  100. }
  101. return handler
  102. }
  103. func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
  104. writeJson func(w http.ResponseWriter, code int, v any),
  105. fns ...func(w http.ResponseWriter, err error)) {
  106. if handler == nil {
  107. if len(fns) > 0 {
  108. for _, fn := range fns {
  109. fn(w, err)
  110. }
  111. } else if errcode.IsGrpcError(err) {
  112. // don't unwrap error and get status.Message(),
  113. // it hides the rpc error headers.
  114. http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
  115. } else {
  116. http.Error(w, err.Error(), http.StatusBadRequest)
  117. }
  118. return
  119. }
  120. code, body := handler(err)
  121. if body == nil {
  122. w.WriteHeader(code)
  123. return
  124. }
  125. switch v := body.(type) {
  126. case error:
  127. http.Error(w, v.Error(), code)
  128. default:
  129. writeJson(w, code, body)
  130. }
  131. }
  132. func doWriteJson(w http.ResponseWriter, code int, v any) error {
  133. bs, err := json.Marshal(v)
  134. if err != nil {
  135. http.Error(w, err.Error(), http.StatusInternalServerError)
  136. return fmt.Errorf("marshal json failed, error: %w", err)
  137. }
  138. w.Header().Set(ContentType, header.JsonContentType)
  139. w.WriteHeader(code)
  140. if n, err := w.Write(bs); err != nil {
  141. // http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
  142. // so it's ignored here.
  143. if !errors.Is(err, http.ErrHandlerTimeout) {
  144. return fmt.Errorf("write response failed, error: %w", err)
  145. }
  146. } else if n < len(bs) {
  147. return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
  148. }
  149. return nil
  150. }