فهرست منبع

refactor: remove duplicated code (#2705)

Kevin Wan 2 سال پیش
والد
کامیت
7a75dce465
1فایلهای تغییر یافته به همراه51 افزوده شده و 64 حذف شده
  1. 51 64
      rest/httpx/responses.go

+ 51 - 64
rest/httpx/responses.go

@@ -3,6 +3,7 @@ package httpx
 import (
 import (
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
+	"fmt"
 	"net/http"
 	"net/http"
 	"sync"
 	"sync"
 
 
@@ -13,8 +14,8 @@ import (
 
 
 var (
 var (
 	errorHandler    func(error) (int, interface{})
 	errorHandler    func(error) (int, interface{})
-	lock            sync.RWMutex
 	errorHandlerCtx func(context.Context, error) (int, interface{})
 	errorHandlerCtx func(context.Context, error) (int, interface{})
+	lock            sync.RWMutex
 )
 )
 
 
 // Error writes err into w.
 // Error writes err into w.
@@ -23,32 +24,26 @@ func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter,
 	handler := errorHandler
 	handler := errorHandler
 	lock.RUnlock()
 	lock.RUnlock()
 
 
-	if handler == nil {
-		if len(fns) > 0 {
-			fns[0](w, err)
-		} else if errcode.IsGrpcError(err) {
-			// don't unwrap error and get status.Message(),
-			// it hides the rpc error headers.
-			http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
-		} else {
-			http.Error(w, err.Error(), http.StatusBadRequest)
-		}
+	doHandleError(w, err, handler, WriteJson, fns...)
+}
 
 
-		return
-	}
+// ErrorCtx writes err into w.
+func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
+	fns ...func(w http.ResponseWriter, err error)) {
+	lock.RLock()
+	handlerCtx := errorHandlerCtx
+	lock.RUnlock()
 
 
-	code, body := handler(err)
-	if body == nil {
-		w.WriteHeader(code)
-		return
+	var handler func(error) (int, interface{})
+	if handlerCtx != nil {
+		handler = func(err error) (int, interface{}) {
+			return handlerCtx(ctx, err)
+		}
 	}
 	}
-
-	e, ok := body.(error)
-	if ok {
-		http.Error(w, e.Error(), code)
-	} else {
-		WriteJson(w, code, body)
+	writeJson := func(w http.ResponseWriter, code int, v interface{}) {
+		WriteJsonCtx(ctx, w, code, v)
 	}
 	}
+	doHandleError(w, err, handler, writeJson, fns...)
 }
 }
 
 
 // Ok writes HTTP 200 OK into w.
 // Ok writes HTTP 200 OK into w.
@@ -61,6 +56,11 @@ func OkJson(w http.ResponseWriter, v interface{}) {
 	WriteJson(w, http.StatusOK, v)
 	WriteJson(w, http.StatusOK, v)
 }
 }
 
 
+// OkJsonCtx writes v into w with 200 OK.
+func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) {
+	WriteJsonCtx(ctx, w, http.StatusOK, v)
+}
+
 // SetErrorHandler sets the error handler, which is called on calling Error.
 // SetErrorHandler sets the error handler, which is called on calling Error.
 func SetErrorHandler(handler func(error) (int, interface{})) {
 func SetErrorHandler(handler func(error) (int, interface{})) {
 	lock.Lock()
 	lock.Lock()
@@ -68,37 +68,35 @@ func SetErrorHandler(handler func(error) (int, interface{})) {
 	errorHandler = handler
 	errorHandler = handler
 }
 }
 
 
+// SetErrorHandlerCtx sets the error handler, which is called on calling Error.
+func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) {
+	lock.Lock()
+	defer lock.Unlock()
+	errorHandlerCtx = handlerCtx
+}
+
 // WriteJson writes v as json string into w with code.
 // WriteJson writes v as json string into w with code.
 func WriteJson(w http.ResponseWriter, code int, v interface{}) {
 func WriteJson(w http.ResponseWriter, code int, v interface{}) {
-	bs, err := json.Marshal(v)
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
+	if err := doWriteJson(w, code, v); err != nil {
+		logx.Error(err)
 	}
 	}
+}
 
 
-	w.Header().Set(ContentType, header.JsonContentType)
-	w.WriteHeader(code)
-
-	if n, err := w.Write(bs); err != nil {
-		// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
-		// so it's ignored here.
-		if err != http.ErrHandlerTimeout {
-			logx.Errorf("write response failed, error: %s", err)
-		}
-	} else if n < len(bs) {
-		logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
+// WriteJsonCtx writes v as json string into w with code.
+func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) {
+	if err := doWriteJson(w, code, v); err != nil {
+		logx.WithContext(ctx).Error(err)
 	}
 	}
 }
 }
 
 
-// Error writes err into w.
-func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
-	lock.RLock()
-	handlerCtx := errorHandlerCtx
-	lock.RUnlock()
-
-	if handlerCtx == nil {
+func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, interface{}),
+	writeJson func(w http.ResponseWriter, code int, v interface{}),
+	fns ...func(w http.ResponseWriter, err error)) {
+	if handler == nil {
 		if len(fns) > 0 {
 		if len(fns) > 0 {
-			fns[0](w, err)
+			for _, fn := range fns {
+				fn(w, err)
+			}
 		} else if errcode.IsGrpcError(err) {
 		} else if errcode.IsGrpcError(err) {
 			// don't unwrap error and get status.Message(),
 			// don't unwrap error and get status.Message(),
 			// it hides the rpc error headers.
 			// it hides the rpc error headers.
@@ -110,7 +108,7 @@ func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func
 		return
 		return
 	}
 	}
 
 
-	code, body := handlerCtx(ctx, err)
+	code, body := handler(err)
 	if body == nil {
 	if body == nil {
 		w.WriteHeader(code)
 		w.WriteHeader(code)
 		return
 		return
@@ -120,21 +118,15 @@ func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func
 	if ok {
 	if ok {
 		http.Error(w, e.Error(), code)
 		http.Error(w, e.Error(), code)
 	} else {
 	} else {
-		WriteJsonCtx(ctx, w, code, body)
+		writeJson(w, code, body)
 	}
 	}
 }
 }
 
 
-// OkJson writes v into w with 200 OK.
-func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) {
-	WriteJsonCtx(ctx, w, http.StatusOK, v)
-}
-
-// WriteJson writes v as json string into w with code.
-func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) {
+func doWriteJson(w http.ResponseWriter, code int, v interface{}) error {
 	bs, err := json.Marshal(v)
 	bs, err := json.Marshal(v)
 	if err != nil {
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
+		return fmt.Errorf("marshal json failed, error: %w", err)
 	}
 	}
 
 
 	w.Header().Set(ContentType, header.JsonContentType)
 	w.Header().Set(ContentType, header.JsonContentType)
@@ -144,16 +136,11 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interf
 		// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
 		// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
 		// so it's ignored here.
 		// so it's ignored here.
 		if err != http.ErrHandlerTimeout {
 		if err != http.ErrHandlerTimeout {
-			logx.WithContext(ctx).Errorf("write response failed, error: %s", err)
+			return fmt.Errorf("write response failed, error: %w", err)
 		}
 		}
 	} else if n < len(bs) {
 	} else if n < len(bs) {
-		logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
+		return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
 	}
 	}
-}
 
 
-// SetErrorHandler sets the error handler, which is called on calling Error.
-func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) {
-	lock.Lock()
-	defer lock.Unlock()
-	errorHandlerCtx = handlerCtx
+	return nil
 }
 }