1
0
Эх сурвалжийг харах

fix: #2672 (#2681)

* fix: #2672

* chore: fix more cases

* chore: update deps

* chore: update deps

* chore: refactor

* chore: refactor

* chore: refactor
Kevin Wan 2 жил өмнө
parent
commit
fdc57d07d7

+ 5 - 5
core/breaker/nopbreaker.go

@@ -20,16 +20,16 @@ func (b noOpBreaker) Do(req func() error) error {
 	return req()
 }
 
-func (b noOpBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
+func (b noOpBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
 	return req()
 }
 
-func (b noOpBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
+func (b noOpBreaker) DoWithFallback(req func() error, _ func(err error) error) error {
 	return req()
 }
 
-func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
-	acceptable Acceptable) error {
+func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, _ func(err error) error,
+	_ Acceptable) error {
 	return req()
 }
 
@@ -38,5 +38,5 @@ type nopPromise struct{}
 func (p nopPromise) Accept() {
 }
 
-func (p nopPromise) Reject(reason string) {
+func (p nopPromise) Reject(_ string) {
 }

+ 12 - 0
core/mapping/unmarshaler.go

@@ -735,8 +735,16 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
 		default:
 			switch v := keythData.(type) {
 			case bool:
+				if dereffedElemKind != reflect.Bool {
+					return emptyValue, errTypeMismatch
+				}
+
 				targetValue.SetMapIndex(key, reflect.ValueOf(v))
 			case string:
+				if dereffedElemKind != reflect.String {
+					return emptyValue, errTypeMismatch
+				}
+
 				targetValue.SetMapIndex(key, reflect.ValueOf(v))
 			case json.Number:
 				target := reflect.New(dereffedElemType)
@@ -746,6 +754,10 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
 
 				targetValue.SetMapIndex(key, target.Elem())
 			default:
+				if dereffedElemKind != keythValue.Kind() {
+					return emptyValue, errTypeMismatch
+				}
+
 				targetValue.SetMapIndex(key, keythValue)
 			}
 		}

+ 65 - 0
core/mapping/unmarshaler_test.go

@@ -3563,6 +3563,71 @@ func TestGoogleUUID(t *testing.T) {
 	assert.Equal(t, "6ba7b810-9dad-11d1-80b4-00c04fd430c2", val.Uidp.String())
 }
 
+func TestUnmarshalJsonReaderWithTypeMismatchBool(t *testing.T) {
+	var req struct {
+		Params map[string]bool `json:"params"`
+	}
+	body := `{"params":{"a":"123"}}`
+	assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(body), &req))
+}
+
+func TestUnmarshalJsonReaderWithTypeMismatchString(t *testing.T) {
+	var req struct {
+		Params map[string]string `json:"params"`
+	}
+	body := `{"params":{"a":{"a":123}}}`
+	assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(body), &req))
+}
+
+func TestUnmarshalJsonReaderWithMismatchType(t *testing.T) {
+	type Req struct {
+		Params map[string]string `json:"params"`
+	}
+
+	var req Req
+	body := `{"params":{"a":{"a":123}}}`
+	assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(body), &req))
+}
+
+func TestUnmarshalJsonReaderWithMismatchTypeBool(t *testing.T) {
+	type Req struct {
+		Params map[string]bool `json:"params"`
+	}
+
+	tests := []struct {
+		name  string
+		input string
+	}{
+		{
+			name:  "int",
+			input: `{"params":{"a":123}}`,
+		},
+		{
+			name:  "int",
+			input: `{"params":{"a":"123"}}`,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			var req Req
+			assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(test.input), &req))
+		})
+	}
+}
+
+func TestUnmarshalJsonReaderWithMismatchTypeBoolMap(t *testing.T) {
+	var req struct {
+		Params map[string]string `json:"params"`
+	}
+	assert.Equal(t, errTypeMismatch, UnmarshalJsonMap(map[string]interface{}{
+		"params": map[string]interface{}{
+			"a": true,
+		},
+	}, &req))
+}
+
 func BenchmarkDefaultValue(b *testing.B) {
 	for i := 0; i < b.N; i++ {
 		var a struct {

+ 8 - 1
core/mapping/utils.go

@@ -82,7 +82,14 @@ func ValidatePtr(v *reflect.Value) error {
 func convertType(kind reflect.Kind, str string) (interface{}, error) {
 	switch kind {
 	case reflect.Bool:
-		return str == "1" || strings.ToLower(str) == "true", nil
+		switch strings.ToLower(str) {
+		case "1", "true":
+			return true, nil
+		case "0", "false":
+			return false, nil
+		default:
+			return false, errTypeMismatch
+		}
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 		intValue, err := strconv.ParseInt(str, 10, 64)
 		if err != nil {

+ 2 - 1
core/rescue/recover.go

@@ -4,7 +4,8 @@ import "github.com/zeromicro/go-zero/core/logx"
 
 // Recover is used with defer to do cleanup on panics.
 // Use it like:
-//  defer Recover(func() {})
+//
+//	defer Recover(func() {})
 func Recover(cleanups ...func()) {
 	for _, cleanup := range cleanups {
 		cleanup()

+ 0 - 1
core/stat/internal/cpu_linux.go

@@ -34,7 +34,6 @@ func initialize() {
 
 	cores = uint64(len(cpus))
 	quota = float64(len(cpus))
-
 	cq, err := cpuQuota()
 	if err == nil {
 		if cq != -1 {

+ 1 - 1
go.mod

@@ -24,7 +24,7 @@ require (
 	github.com/stretchr/testify v1.8.1
 	go.etcd.io/etcd/api/v3 v3.5.5
 	go.etcd.io/etcd/client/v3 v3.5.5
-	go.mongodb.org/mongo-driver v1.11.0
+	go.mongodb.org/mongo-driver v1.11.1
 	go.opentelemetry.io/otel v1.10.0
 	go.opentelemetry.io/otel/exporters/jaeger v1.10.0
 	go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.10.0

+ 2 - 2
go.sum

@@ -840,8 +840,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.5 h1:9S0JUVvmrVl7wCF39iTQthdaaNIiAaQbmK75ogO6
 go.etcd.io/etcd/client/pkg/v3 v3.5.5/go.mod h1:ggrwbk069qxpKPq8/FKkQ3Xq9y39kbFR4LnKszpRXeQ=
 go.etcd.io/etcd/client/v3 v3.5.5 h1:q++2WTJbUgpQu4B6hCuT7VkdwaTP7Qz6Daak3WzbrlI=
 go.etcd.io/etcd/client/v3 v3.5.5/go.mod h1:aApjR4WGlSumpnJ2kloS75h6aHUmAyaPLjHMxpc7E7c=
-go.mongodb.org/mongo-driver v1.11.0 h1:FZKhBSTydeuffHj9CBjXlR8vQLee1cQyTWYPA6/tqiE=
-go.mongodb.org/mongo-driver v1.11.0/go.mod h1:s7p5vEtfbeR1gYi6pnj3c3/urpbLv2T5Sfd6Rp2HBB8=
+go.mongodb.org/mongo-driver v1.11.1 h1:QP0znIRTuL0jf1oBQoAoM0C6ZJfBK4kx0Uumtv1A7w8=
+go.mongodb.org/mongo-driver v1.11.1/go.mod h1:s7p5vEtfbeR1gYi6pnj3c3/urpbLv2T5Sfd6Rp2HBB8=
 go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
 go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
 go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=

+ 2 - 5
internal/devserver/server.go

@@ -9,15 +9,12 @@ import (
 
 	"github.com/felixge/fgprof"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
-
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/threading"
 	"github.com/zeromicro/go-zero/internal/health"
 )
 
-var (
-	once sync.Once
-)
+var once sync.Once
 
 // Server is inner http server, expose some useful observability information of app.
 // For example health check, metrics and pprof.
@@ -68,7 +65,7 @@ func (s *Server) StartAsync() {
 	s.addRoutes()
 	threading.GoSafe(func() {
 		addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
-		logx.Infof("Starting inner http server at %s", addr)
+		logx.Infof("Starting dev http server at %s", addr)
 		if err := http.ListenAndServe(addr, s.server); err != nil {
 			logx.Error(err)
 		}

+ 20 - 15
internal/health/health.go

@@ -1,7 +1,9 @@
 package health
 
 import (
+	"fmt"
 	"net/http"
+	"strings"
 	"sync"
 
 	"github.com/zeromicro/go-zero/core/syncx"
@@ -41,6 +43,18 @@ func AddProbe(probe Probe) {
 	defaultHealthManager.addProbe(probe)
 }
 
+// CreateHttpHandler create health http handler base on given probe.
+func CreateHttpHandler() http.HandlerFunc {
+	return func(w http.ResponseWriter, _ *http.Request) {
+		if defaultHealthManager.IsReady() {
+			_, _ = w.Write([]byte("OK"))
+		} else {
+			http.Error(w, "Service Unavailable\n"+defaultHealthManager.verboseInfo(),
+				http.StatusServiceUnavailable)
+		}
+	}
+}
+
 // NewHealthManager returns a new healthManager.
 func NewHealthManager(name string) Probe {
 	return &healthManager{
@@ -102,6 +116,7 @@ func (p *comboHealthManager) IsReady() bool {
 			return false
 		}
 	}
+
 	return true
 }
 
@@ -109,15 +124,16 @@ func (p *comboHealthManager) verboseInfo() string {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
-	var info string
+	var info strings.Builder
 	for _, probe := range p.probes {
 		if probe.IsReady() {
-			info += probe.Name() + " is ready; \n"
+			info.WriteString(fmt.Sprintf("%s is ready\n", probe.Name()))
 		} else {
-			info += probe.Name() + " is not ready; \n"
+			info.WriteString(fmt.Sprintf("%s is not ready\n", probe.Name()))
 		}
 	}
-	return info
+
+	return info.String()
 }
 
 // addProbe add components probe to comboHealthManager.
@@ -127,14 +143,3 @@ func (p *comboHealthManager) addProbe(probe Probe) {
 
 	p.probes = append(p.probes, probe)
 }
-
-// CreateHttpHandler create health http handler base on given probe.
-func CreateHttpHandler() http.HandlerFunc {
-	return func(w http.ResponseWriter, request *http.Request) {
-		if defaultHealthManager.IsReady() {
-			_, _ = w.Write([]byte("OK"))
-		} else {
-			http.Error(w, "Service Unavailable\n"+defaultHealthManager.verboseInfo(), http.StatusServiceUnavailable)
-		}
-	}
-}

+ 10 - 10
internal/health/health_test.go

@@ -54,7 +54,7 @@ func TestComboHealthManager(t *testing.T) {
 	})
 
 	t.Run("concurrent add probes", func(t *testing.T) {
-		chm2 := newComboHealthManager()
+		chm := newComboHealthManager()
 
 		var wg sync.WaitGroup
 		wg.Add(10)
@@ -62,28 +62,28 @@ func TestComboHealthManager(t *testing.T) {
 			go func() {
 				hm := NewHealthManager(probeName)
 				hm.MarkReady()
-				chm2.addProbe(hm)
+				chm.addProbe(hm)
 				wg.Done()
 			}()
 		}
 		wg.Wait()
-		assert.True(t, chm2.IsReady())
+		assert.True(t, chm.IsReady())
 	})
 
 	t.Run("markReady and markNotReady", func(t *testing.T) {
-		chm2 := newComboHealthManager()
+		chm := newComboHealthManager()
 
 		for i := 0; i < 10; i++ {
 			hm := NewHealthManager(probeName)
-			chm2.addProbe(hm)
+			chm.addProbe(hm)
 		}
-		assert.False(t, chm2.IsReady())
+		assert.False(t, chm.IsReady())
 
-		chm2.MarkReady()
-		assert.True(t, chm2.IsReady())
+		chm.MarkReady()
+		assert.True(t, chm.IsReady())
 
-		chm2.MarkNotReady()
-		assert.False(t, chm2.IsReady())
+		chm.MarkNotReady()
+		assert.False(t, chm.IsReady())
 	})
 }
 

+ 21 - 14
rest/chain/chain.go

@@ -37,35 +37,41 @@ func New(middlewares ...Middleware) Chain {
 
 // Append extends a chain, adding the specified middlewares as the last ones in the request flow.
 //
-//     c := chain.New(m1, m2)
-//     c.Append(m3, m4)
-//     // requests in c go m1 -> m2 -> m3 -> m4
+//	c := chain.New(m1, m2)
+//	c.Append(m3, m4)
+//	// requests in c go m1 -> m2 -> m3 -> m4
 func (c chain) Append(middlewares ...Middleware) Chain {
 	return chain{middlewares: join(c.middlewares, middlewares)}
 }
 
 // Prepend extends a chain by adding the specified chain as the first one in the request flow.
 //
-//     c := chain.New(m3, m4)
-//     c1 := chain.New(m1, m2)
-//     c.Prepend(c1)
-//     // requests in c go m1 -> m2 -> m3 -> m4
+//	c := chain.New(m3, m4)
+//	c1 := chain.New(m1, m2)
+//	c.Prepend(c1)
+//	// requests in c go m1 -> m2 -> m3 -> m4
 func (c chain) Prepend(middlewares ...Middleware) Chain {
 	return chain{middlewares: join(middlewares, c.middlewares)}
 }
 
 // Then chains the middleware and returns the final http.Handler.
-//     New(m1, m2, m3).Then(h)
+//
+//	New(m1, m2, m3).Then(h)
+//
 // is equivalent to:
-//     m1(m2(m3(h)))
+//
+//	m1(m2(m3(h)))
+//
 // When the request comes in, it will be passed to m1, then m2, then m3
 // and finally, the given handler
 // (assuming every middleware calls the following one).
 //
 // A chain can be safely reused by calling Then() several times.
-//     stdStack := chain.New(ratelimitHandler, csrfHandler)
-//     indexPipe = stdStack.Then(indexHandler)
-//     authPipe = stdStack.Then(authHandler)
+//
+//	stdStack := chain.New(ratelimitHandler, csrfHandler)
+//	indexPipe = stdStack.Then(indexHandler)
+//	authPipe = stdStack.Then(authHandler)
+//
 // Note that middlewares are called on every call to Then() or ThenFunc()
 // and thus several instances of the same middleware will be created
 // when a chain is reused in this way.
@@ -88,8 +94,9 @@ func (c chain) Then(h http.Handler) http.Handler {
 // a HandlerFunc instead of a Handler.
 //
 // The following two statements are equivalent:
-//     c.Then(http.HandlerFunc(fn))
-//     c.ThenFunc(fn)
+//
+//	c.Then(http.HandlerFunc(fn))
+//	c.ThenFunc(fn)
 //
 // ThenFunc provides all the guarantees of Then.
 func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {

+ 2 - 2
tools/goctl/rpc/generator/gen_test.go

@@ -40,7 +40,7 @@ func TestRpcGenerate(t *testing.T) {
 	// case go path
 	t.Run("GOPATH", func(t *testing.T) {
 		ctx := &ZRpcContext{
-			Src:        "./test.proto",
+			Src: "./test.proto",
 			ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s",
 				common, projectDir, projectDir),
 			IsGooglePlugin: true,
@@ -71,7 +71,7 @@ func TestRpcGenerate(t *testing.T) {
 
 		projectDir = filepath.Join(workDir, projectName)
 		ctx := &ZRpcContext{
-			Src:        "./test.proto",
+			Src: "./test.proto",
 			ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s",
 				common, projectDir, projectDir),
 			IsGooglePlugin: true,

+ 7 - 8
zrpc/internal/rpcserver.go

@@ -4,13 +4,12 @@ import (
 	"fmt"
 	"net"
 
-	"google.golang.org/grpc"
-	"google.golang.org/grpc/health/grpc_health_v1"
-
 	"github.com/zeromicro/go-zero/core/proc"
 	"github.com/zeromicro/go-zero/core/stat"
 	"github.com/zeromicro/go-zero/internal/health"
 	"github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/health/grpc_health_v1"
 )
 
 const probeNamePrefix = "zrpc"
@@ -25,25 +24,25 @@ type (
 	}
 
 	rpcServer struct {
-		name string
 		*baseRpcServer
+		name          string
 		healthManager health.Probe
 	}
 )
 
 // NewRpcServer returns a Server.
-func NewRpcServer(address string, opts ...ServerOption) Server {
+func NewRpcServer(addr string, opts ...ServerOption) Server {
 	var options rpcServerOptions
 	for _, opt := range opts {
 		opt(&options)
 	}
 	if options.metrics == nil {
-		options.metrics = stat.NewMetrics(address)
+		options.metrics = stat.NewMetrics(addr)
 	}
 
 	return &rpcServer{
-		baseRpcServer: newBaseRpcServer(address, &options),
-		healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, address)),
+		baseRpcServer: newBaseRpcServer(addr, &options),
+		healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, addr)),
 	}
 }