123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- package httpx
- import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "sync"
- "github.com/wuntsong-org/go-zero-plus/core/logx"
- "github.com/wuntsong-org/go-zero-plus/rest/internal/errcode"
- "github.com/wuntsong-org/go-zero-plus/rest/internal/header"
- )
// Error-hook state. SetErrorHandler and SetErrorHandlerCtx replace
// errorHandler under the write lock; buildErrorHandler copies it under the
// read lock before use.
var (
	// errorLock guards errorHandler.
	errorLock sync.RWMutex
	// errorHandler, when non-nil, maps an error to the status code and body
	// written by Error and ErrorCtx.
	errorHandler func(context.Context, error) (int, any)
)

// Success-hook state. SetOkHandler replaces okHandler under the write lock;
// OkJson and OkJsonCtx copy it under the read lock before use.
var (
	// okLock guards okHandler.
	okLock sync.RWMutex
	// okHandler, when non-nil, transforms the value written by OkJson and
	// OkJsonCtx before it is encoded.
	okHandler func(context.Context, any) any
)
- // Error writes err into w.
- func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
- doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
- }
- // ErrorCtx writes err into w.
- func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
- fns ...func(w http.ResponseWriter, err error)) {
- writeJson := func(w http.ResponseWriter, code int, v any) {
- WriteJsonCtx(ctx, w, code, v)
- }
- doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
- }
// Ok writes the HTTP 200 OK status into w and writes no body.
func Ok(w http.ResponseWriter) {
	const status = http.StatusOK
	w.WriteHeader(status)
}
- // OkJson writes v into w with 200 OK.
- func OkJson(w http.ResponseWriter, v any) {
- okLock.RLock()
- handler := okHandler
- okLock.RUnlock()
- if handler != nil {
- v = handler(context.Background(), v)
- }
- WriteJson(w, http.StatusOK, v)
- }
- // OkJsonCtx writes v into w with 200 OK.
- func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
- okLock.RLock()
- handlerCtx := okHandler
- okLock.RUnlock()
- if handlerCtx != nil {
- v = handlerCtx(ctx, v)
- }
- WriteJsonCtx(ctx, w, http.StatusOK, v)
- }
- // SetErrorHandler sets the error handler, which is called on calling Error.
- // Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
- // Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
- func SetErrorHandler(handler func(error) (int, any)) {
- errorLock.Lock()
- defer errorLock.Unlock()
- errorHandler = func(_ context.Context, err error) (int, any) {
- return handler(err)
- }
- }
- // SetErrorHandlerCtx sets the error handler, which is called on calling Error.
- // Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
- // Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
- func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
- errorLock.Lock()
- defer errorLock.Unlock()
- errorHandler = handlerCtx
- }
- // SetOkHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
- func SetOkHandler(handler func(context.Context, any) any) {
- okLock.Lock()
- defer okLock.Unlock()
- okHandler = handler
- }
- // WriteJson writes v as json string into w with code.
- func WriteJson(w http.ResponseWriter, code int, v any) {
- if err := doWriteJson(w, code, v); err != nil {
- logx.Error(err)
- }
- }
- // WriteJsonCtx writes v as json string into w with code.
- func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
- if err := doWriteJson(w, code, v); err != nil {
- logx.WithContext(ctx).Error(err)
- }
- }
- func buildErrorHandler(ctx context.Context) func(error) (int, any) {
- errorLock.RLock()
- handlerCtx := errorHandler
- errorLock.RUnlock()
- var handler func(error) (int, any)
- if handlerCtx != nil {
- handler = func(err error) (int, any) {
- return handlerCtx(ctx, err)
- }
- }
- return handler
- }
- func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
- writeJson func(w http.ResponseWriter, code int, v any),
- fns ...func(w http.ResponseWriter, err error)) {
- if handler == nil {
- if len(fns) > 0 {
- for _, fn := range fns {
- fn(w, err)
- }
- } else if errcode.IsGrpcError(err) {
- // don't unwrap error and get status.Message(),
- // it hides the rpc error headers.
- http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
- } else {
- http.Error(w, err.Error(), http.StatusBadRequest)
- }
- return
- }
- code, body := handler(err)
- if body == nil {
- w.WriteHeader(code)
- return
- }
- switch v := body.(type) {
- case error:
- http.Error(w, v.Error(), code)
- default:
- writeJson(w, code, body)
- }
- }
- func doWriteJson(w http.ResponseWriter, code int, v any) error {
- bs, err := json.Marshal(v)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return fmt.Errorf("marshal json failed, error: %w", err)
- }
- w.Header().Set(ContentType, header.JsonContentType)
- w.WriteHeader(code)
- if n, err := w.Write(bs); err != nil {
- // http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
- // so it's ignored here.
- if !errors.Is(err, http.ErrHandlerTimeout) {
- return fmt.Errorf("write response failed, error: %w", err)
- }
- } else if n < len(bs) {
- return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
- }
- return nil
- }
|