Преглед изворни кода

chore: refine rest validator (#2928)

* chore: refine rest validator

* chore: add more tests

* chore: reformat code

* chore: add comments
Kevin Wan пре 2 година
родитељ
комит
66be213346
2 измењених фајлова са 77 додато и 11 уклоњено
  1. 15 8
      rest/httpx/requests.go
  2. 62 3
      rest/httpx/requests_test.go

+ 15 - 8
rest/httpx/requests.go

@@ -4,6 +4,7 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync/atomic"
 
 	"github.com/zeromicro/go-zero/core/mapping"
 	"github.com/zeromicro/go-zero/rest/internal/encoding"
@@ -23,15 +24,13 @@ const (
 var (
 	formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
 	pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
-	xValidator      Validator
+	validator       atomic.Value
 )
 
+// Validator defines the interface for validating the request.
 type Validator interface {
-	Validate(data interface{}, lang string) error
-}
-
-func SetValidator(validator Validator) {
-	xValidator = validator
+	// Validate validates the request and parsed data.
+	Validate(r *http.Request, data any) error
 }
 
 // Parse parses the request.
@@ -52,9 +51,10 @@ func Parse(r *http.Request, v any) error {
 		return err
 	}
 
-	if xValidator != nil {
-		return xValidator.Validate(v, r.Header.Get("Accept-Language"))
+	if val := validator.Load(); val != nil {
+		return val.(Validator).Validate(r, v)
 	}
+
 	return nil
 }
 
@@ -117,6 +117,13 @@ func ParsePath(r *http.Request, v any) error {
 	return pathUnmarshaler.Unmarshal(m, v)
 }
 
+// SetValidator sets the validator.
+// The validator is used to validate the request, only called in Parse,
+// not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath.
+func SetValidator(val Validator) {
+	validator.Store(val)
+}
+
 func withJsonBody(r *http.Request) bool {
 	return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
 }

+ 62 - 3
rest/httpx/requests_test.go

@@ -1,8 +1,10 @@
 package httpx
 
 import (
+	"errors"
 	"net/http"
 	"net/http/httptest"
+	"reflect"
 	"strconv"
 	"strings"
 	"testing"
@@ -207,9 +209,23 @@ func TestParseJsonBody(t *testing.T) {
 		r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
 		r.Header.Set(ContentType, header.JsonContentType)
 
-		assert.Nil(t, Parse(r, &v))
-		assert.Equal(t, "kevin", v.Name)
-		assert.Equal(t, 18, v.Age)
+		if assert.NoError(t, Parse(r, &v)) {
+			assert.Equal(t, "kevin", v.Name)
+			assert.Equal(t, 18, v.Age)
+		}
+	})
+
+	t.Run("bad body", func(t *testing.T) {
+		var v struct {
+			Name string `json:"name"`
+			Age  int    `json:"age"`
+		}
+
+		body := `{"name":"kevin", "ag": 18}`
+		r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
+		r.Header.Set(ContentType, header.JsonContentType)
+
+		assert.Error(t, Parse(r, &v))
 	})
 
 	t.Run("hasn't body", func(t *testing.T) {
@@ -308,6 +324,36 @@ func TestParseHeaders_Error(t *testing.T) {
 	assert.NotNil(t, Parse(r, &v))
 }
 
+func TestParseWithValidator(t *testing.T) {
+	SetValidator(mockValidator{})
+	var v struct {
+		Name    string  `form:"name"`
+		Age     int     `form:"age"`
+		Percent float64 `form:"percent,optional"`
+	}
+
+	r, err := http.NewRequest(http.MethodGet, "/a?name=hello&age=18&percent=3.4", http.NoBody)
+	assert.Nil(t, err)
+	if assert.NoError(t, Parse(r, &v)) {
+		assert.Equal(t, "hello", v.Name)
+		assert.Equal(t, 18, v.Age)
+		assert.Equal(t, 3.4, v.Percent)
+	}
+}
+
+func TestParseWithValidatorWithError(t *testing.T) {
+	SetValidator(mockValidator{})
+	var v struct {
+		Name    string  `form:"name"`
+		Age     int     `form:"age"`
+		Percent float64 `form:"percent,optional"`
+	}
+
+	r, err := http.NewRequest(http.MethodGet, "/a?name=world&age=18&percent=3.4", http.NoBody)
+	assert.Nil(t, err)
+	assert.Error(t, Parse(r, &v))
+}
+
 func BenchmarkParseRaw(b *testing.B) {
 	r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
 	if err != nil {
@@ -351,3 +397,16 @@ func BenchmarkParseAuto(b *testing.B) {
 		}
 	}
 }
+
+type mockValidator struct{}
+
+func (m mockValidator) Validate(r *http.Request, data any) error {
+	if r.URL.Path == "/a" {
+		val := reflect.ValueOf(data).Elem().FieldByName("Name").String()
+		if val != "hello" {
+			return errors.New("name is not hello")
+		}
+	}
+
+	return nil
+}