Quellcode durchsuchen

chore: refactor and add more tests (#3351)

Kevin Wan vor 1 Jahr
Ursprung
Commit
f998803131
2 geänderte Dateien mit 120 neuen und 63 gelöschten Zeilen
  1. 45 47
      rest/httpx/responses.go
  2. 75 16
      rest/httpx/responses_test.go

+ 45 - 47
rest/httpx/responses.go

@@ -13,40 +13,24 @@ import (
 )
 
 var (
-	errorHandler     func(error) (int, any)
-	errorHandlerCtx  func(context.Context, error) (int, any)
-	commonHandler    func(any) any
-	commonHandlerCtx func(context.Context, any) any
-	lock             sync.RWMutex
-	cLock            sync.RWMutex
+	errorHandler func(context.Context, error) (int, any)
+	errorLock    sync.RWMutex
+	respHandler  func(context.Context, any) any
+	respLock     sync.RWMutex
 )
 
 // Error writes err into w.
 func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
-	lock.RLock()
-	handler := errorHandler
-	lock.RUnlock()
-
-	doHandleError(w, err, handler, WriteJson, fns...)
+	doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
 }
 
 // ErrorCtx writes err into w.
 func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
 	fns ...func(w http.ResponseWriter, err error)) {
-	lock.RLock()
-	handlerCtx := errorHandlerCtx
-	lock.RUnlock()
-
-	var handler func(error) (int, any)
-	if handlerCtx != nil {
-		handler = func(err error) (int, any) {
-			return handlerCtx(ctx, err)
-		}
-	}
 	writeJson := func(w http.ResponseWriter, code int, v any) {
 		WriteJsonCtx(ctx, w, code, v)
 	}
-	doHandleError(w, err, handler, writeJson, fns...)
+	doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
 }
 
 // Ok writes HTTP 200 OK into w.
@@ -56,20 +40,20 @@ func Ok(w http.ResponseWriter) {
 
 // OkJson writes v into w with 200 OK.
 func OkJson(w http.ResponseWriter, v any) {
-	cLock.RLock()
-	handler := commonHandler
-	cLock.RUnlock()
+	respLock.RLock()
+	handler := respHandler
+	respLock.RUnlock()
 	if handler != nil {
-		v = handler(v)
+		v = handler(context.Background(), v)
 	}
 	WriteJson(w, http.StatusOK, v)
 }
 
 // OkJsonCtx writes v into w with 200 OK.
 func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
-	cLock.RLock()
-	handlerCtx := commonHandlerCtx
-	cLock.RUnlock()
+	respLock.RLock()
+	handlerCtx := respHandler
+	respLock.RUnlock()
 	if handlerCtx != nil {
 		v = handlerCtx(ctx, v)
 	}
@@ -77,31 +61,30 @@ func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
 }
 
 // SetErrorHandler sets the error handler, which is called on calling Error.
+// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
+// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
 func SetErrorHandler(handler func(error) (int, any)) {
-	lock.Lock()
-	defer lock.Unlock()
-	errorHandler = handler
+	errorLock.Lock()
+	defer errorLock.Unlock()
+	errorHandler = func(_ context.Context, err error) (int, any) {
+		return handler(err)
+	}
 }
 
 // SetErrorHandlerCtx sets the error handler, which is called on calling Error.
+// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
+// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
 func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
-	lock.Lock()
-	defer lock.Unlock()
-	errorHandlerCtx = handlerCtx
-}
-
-// SetCommonHandler sets the common handler, which is called on calling OkJson.
-func SetCommonHandler(handler func(any) any) {
-	cLock.Lock()
-	defer cLock.Unlock()
-	commonHandler = handler
+	errorLock.Lock()
+	defer errorLock.Unlock()
+	errorHandler = handlerCtx
 }
 
-// SetCommonHandlerCtx sets the common handler, which is called on calling OkJson.
-func SetCommonHandlerCtx(handlerCtx func(context.Context, any) any) {
-	cLock.Lock()
-	defer cLock.Unlock()
-	commonHandlerCtx = handlerCtx
+// SetResponseHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
+func SetResponseHandler(handler func(context.Context, any) any) {
+	respLock.Lock()
+	defer respLock.Unlock()
+	respHandler = handler
 }
 
 // WriteJson writes v as json string into w with code.
@@ -118,6 +101,21 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
 	}
 }
 
+func buildErrorHandler(ctx context.Context) func(error) (int, any) {
+	errorLock.RLock()
+	handlerCtx := errorHandler
+	errorLock.RUnlock()
+
+	var handler func(error) (int, any)
+	if handlerCtx != nil {
+		handler = func(err error) (int, any) {
+			return handlerCtx(ctx, err)
+		}
+	}
+
+	return handler
+}
+
 func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
 	writeJson func(w http.ResponseWriter, code int, v any),
 	fns ...func(w http.ResponseWriter, err error)) {

+ 75 - 16
rest/httpx/responses_test.go

@@ -3,6 +3,7 @@ package httpx
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
 	"strings"
 	"testing"
@@ -80,14 +81,14 @@ func TestError(t *testing.T) {
 				headers: make(map[string][]string),
 			}
 			if test.errorHandler != nil {
-				lock.RLock()
+				errorLock.RLock()
 				prev := errorHandler
-				lock.RUnlock()
+				errorLock.RUnlock()
 				SetErrorHandler(test.errorHandler)
 				defer func() {
-					lock.Lock()
+					errorLock.Lock()
 					errorHandler = prev
-					lock.Unlock()
+					errorLock.Unlock()
 				}()
 			}
 			Error(&w, errors.New(test.input))
@@ -129,13 +130,71 @@ func TestOk(t *testing.T) {
 }
 
 func TestOkJson(t *testing.T) {
-	w := tracedResponseWriter{
-		headers: make(map[string][]string),
-	}
-	msg := message{Name: "anyone"}
-	OkJson(&w, msg)
-	assert.Equal(t, http.StatusOK, w.code)
-	assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
+	t.Run("no handler", func(t *testing.T) {
+		w := tracedResponseWriter{
+			headers: make(map[string][]string),
+		}
+		msg := message{Name: "anyone"}
+		OkJson(&w, msg)
+		assert.Equal(t, http.StatusOK, w.code)
+		assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
+	})
+
+	t.Run("with handler", func(t *testing.T) {
+		respLock.RLock()
+		prev := respHandler
+		respLock.RUnlock()
+		t.Cleanup(func() {
+			respLock.Lock()
+			respHandler = prev
+			respLock.Unlock()
+		})
+
+		SetResponseHandler(func(_ context.Context, v interface{}) any {
+			return fmt.Sprintf("hello %s", v.(message).Name)
+		})
+		w := tracedResponseWriter{
+			headers: make(map[string][]string),
+		}
+		msg := message{Name: "anyone"}
+		OkJson(&w, msg)
+		assert.Equal(t, http.StatusOK, w.code)
+		assert.Equal(t, `"hello anyone"`, w.builder.String())
+	})
+}
+
+func TestOkJsonCtx(t *testing.T) {
+	t.Run("no handler", func(t *testing.T) {
+		w := tracedResponseWriter{
+			headers: make(map[string][]string),
+		}
+		msg := message{Name: "anyone"}
+		OkJsonCtx(context.Background(), &w, msg)
+		assert.Equal(t, http.StatusOK, w.code)
+		assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
+	})
+
+	t.Run("with handler", func(t *testing.T) {
+		respLock.RLock()
+		prev := respHandler
+		respLock.RUnlock()
+		t.Cleanup(func() {
+			respLock.Lock()
+			respHandler = prev
+			respLock.Unlock()
+		})
+
+		SetResponseHandler(func(_ context.Context, v interface{}) any {
+			return fmt.Sprintf("hello %s", v.(message).Name)
+		})
+		w := tracedResponseWriter{
+			headers: make(map[string][]string),
+		}
+		msg := message{Name: "anyone"}
+		OkJsonCtx(context.Background(), &w, msg)
+		assert.Equal(t, http.StatusOK, w.code)
+		assert.Equal(t, `"hello anyone"`, w.builder.String())
+	})
 }
 
 func TestWriteJsonTimeout(t *testing.T) {
@@ -275,14 +334,14 @@ func TestErrorCtx(t *testing.T) {
 				headers: make(map[string][]string),
 			}
 			if test.errorHandlerCtx != nil {
-				lock.RLock()
-				prev := errorHandlerCtx
-				lock.RUnlock()
+				errorLock.RLock()
+				prev := errorHandler
+				errorLock.RUnlock()
 				SetErrorHandlerCtx(test.errorHandlerCtx)
 				defer func() {
-					lock.Lock()
+					errorLock.Lock()
 					test.errorHandlerCtx = prev
-					lock.Unlock()
+					errorLock.Unlock()
 				}()
 			}
 			ErrorCtx(context.Background(), &w, errors.New(test.input))