Forráskód Böngészése

feat: verify RpcPath on startup (#2159)

* feat: verify RpcPath on startup

* feat: support http header Grpc-Timeout
Kevin Wan 2 éve
szülő
commit
557383fbbf

+ 2 - 2
gateway/config.go

@@ -21,8 +21,8 @@ type (
 		Method string
 		// Path is the HTTP path.
 		Path string
-		// Rpc is the gRPC rpc method, with format of package.service/method
-		Rpc string
+		// RpcPath is the gRPC rpc method, with format of package.service/method
+		RpcPath string
 	}
 
 	// upstream is the configuration for upstream.

+ 34 - 0
gateway/internal/descriptorsource.go

@@ -0,0 +1,34 @@
+package internal
+
+import (
+	"fmt"
+
+	"github.com/fullstorydev/grpcurl"
+	"github.com/jhump/protoreflect/desc"
+)
+
+// GetMethods returns all methods of the given grpcurl.DescriptorSource.
+func GetMethods(source grpcurl.DescriptorSource) ([]string, error) {
+	svcs, err := source.ListServices()
+	if err != nil {
+		return nil, err
+	}
+
+	var methods []string
+	for _, svc := range svcs {
+		d, err := source.FindSymbol(svc)
+		if err != nil {
+			return nil, err
+		}
+
+		switch val := d.(type) {
+		case *desc.ServiceDescriptor:
+			svcMethods := val.GetMethods()
+			for _, method := range svcMethods {
+				methods = append(methods, fmt.Sprintf("%s/%s", svc, method.GetName()))
+			}
+		}
+	}
+
+	return methods, nil
+}

+ 29 - 0
gateway/internal/descriptorsource_test.go

@@ -0,0 +1,29 @@
+package internal
+
+import (
+	"encoding/base64"
+	"io/ioutil"
+	"os"
+	"testing"
+
+	"github.com/fullstorydev/grpcurl"
+	"github.com/stretchr/testify/assert"
+	"github.com/zeromicro/go-zero/core/hash"
+)
+
+const b64pb = `CpgBCgtoZWxsby5wcm90bxIFaGVsbG8iHQoHUmVxdWVzdBISCgRwaW5nGAEgASgJUgRwaW5nIh4KCFJlc3BvbnNlEhIKBHBvbmcYASABKAlSBHBvbmcyMAoFSGVsbG8SJwoEUGluZxIOLmhlbGxvLlJlcXVlc3QaDy5oZWxsby5SZXNwb25zZUIJWgcuL2hlbGxvYgZwcm90bzM=`
+
+func TestGetMethods(t *testing.T) {
+	tmpfile, err := ioutil.TempFile(os.TempDir(), hash.Md5Hex([]byte(b64pb)))
+	assert.Nil(t, err)
+	b, err := base64.StdEncoding.DecodeString(b64pb)
+	assert.Nil(t, err)
+	assert.Nil(t, ioutil.WriteFile(tmpfile.Name(), b, os.ModeTemporary))
+	defer os.Remove(tmpfile.Name())
+
+	source, err := grpcurl.DescriptorSourceFromProtoSets(tmpfile.Name())
+	assert.Nil(t, err)
+	methods, err := GetMethods(source)
+	assert.Nil(t, err)
+	assert.EqualValues(t, []string{"hello.Hello/Ping"}, methods)
+}

+ 3 - 2
gateway/headerbuilder.go → gateway/internal/headerbuilder.go

@@ -1,4 +1,4 @@
-package gateway
+package internal
 
 import (
 	"fmt"
@@ -11,7 +11,8 @@ const (
 	metadataPrefix       = "gateway-"
 )
 
-func buildHeaders(header http.Header) []string {
+// BuildHeaders builds the headers for the gateway from HTTP headers.
+func BuildHeaders(header http.Header) []string {
 	var headers []string
 
 	for k, v := range header {

+ 3 - 3
gateway/headerbuilder_test.go → gateway/internal/headerbuilder_test.go

@@ -1,4 +1,4 @@
-package gateway
+package internal
 
 import (
 	"net/http/httptest"
@@ -10,12 +10,12 @@ import (
 func TestBuildHeadersNoValue(t *testing.T) {
 	req := httptest.NewRequest("GET", "/", nil)
 	req.Header.Add("a", "b")
-	assert.Nil(t, buildHeaders(req.Header))
+	assert.Nil(t, BuildHeaders(req.Header))
 }
 
 func TestBuildHeadersWithValues(t *testing.T) {
 	req := httptest.NewRequest("GET", "/", nil)
 	req.Header.Add("grpc-metadata-a", "b")
 	req.Header.Add("grpc-metadata-b", "b")
-	assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, buildHeaders(req.Header))
+	assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, BuildHeaders(req.Header))
 }

+ 13 - 12
gateway/requestparser.go → gateway/internal/requestparser.go

@@ -1,4 +1,4 @@
-package gateway
+package internal
 
 import (
 	"bytes"
@@ -11,17 +11,8 @@ import (
 	"github.com/zeromicro/go-zero/rest/pathvar"
 )
 
-func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) (
-	grpcurl.RequestParser, error) {
-	var buf bytes.Buffer
-	if err := json.NewEncoder(&buf).Encode(m); err != nil {
-		return nil, err
-	}
-
-	return grpcurl.NewJSONRequestParser(&buf, resolver), nil
-}
-
-func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
+// NewRequestParser creates a new request parser from the given http.Request and resolver.
+func NewRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
 	vars := pathvar.Vars(r)
 	params, err := httpx.GetFormValues(r)
 	if err != nil {
@@ -50,3 +41,13 @@ func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.Req
 
 	return buildJsonRequestParser(m, resolver)
 }
+
+func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) (
+	grpcurl.RequestParser, error) {
+	var buf bytes.Buffer
+	if err := json.NewEncoder(&buf).Encode(m); err != nil {
+		return nil, err
+	}
+
+	return grpcurl.NewJSONRequestParser(&buf, resolver), nil
+}

+ 7 - 7
gateway/requestparser_test.go → gateway/internal/requestparser_test.go

@@ -1,4 +1,4 @@
-package gateway
+package internal
 
 import (
 	"net/http/httptest"
@@ -11,7 +11,7 @@ import (
 
 func TestNewRequestParserNoVar(t *testing.T) {
 	req := httptest.NewRequest("GET", "/", nil)
-	parser, err := newRequestParser(req, nil)
+	parser, err := NewRequestParser(req, nil)
 	assert.Nil(t, err)
 	assert.NotNil(t, parser)
 }
@@ -19,14 +19,14 @@ func TestNewRequestParserNoVar(t *testing.T) {
 func TestNewRequestParserWithVars(t *testing.T) {
 	req := httptest.NewRequest("GET", "/", nil)
 	req = pathvar.WithVars(req, map[string]string{"a": "b"})
-	parser, err := newRequestParser(req, nil)
+	parser, err := NewRequestParser(req, nil)
 	assert.Nil(t, err)
 	assert.NotNil(t, parser)
 }
 
 func TestNewRequestParserNoVarWithBody(t *testing.T) {
 	req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`))
-	parser, err := newRequestParser(req, nil)
+	parser, err := NewRequestParser(req, nil)
 	assert.Nil(t, err)
 	assert.NotNil(t, parser)
 }
@@ -34,7 +34,7 @@ func TestNewRequestParserNoVarWithBody(t *testing.T) {
 func TestNewRequestParserWithVarsWithBody(t *testing.T) {
 	req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`))
 	req = pathvar.WithVars(req, map[string]string{"c": "d"})
-	parser, err := newRequestParser(req, nil)
+	parser, err := NewRequestParser(req, nil)
 	assert.Nil(t, err)
 	assert.NotNil(t, parser)
 }
@@ -42,14 +42,14 @@ func TestNewRequestParserWithVarsWithBody(t *testing.T) {
 func TestNewRequestParserWithVarsWithWrongBody(t *testing.T) {
 	req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"`))
 	req = pathvar.WithVars(req, map[string]string{"c": "d"})
-	parser, err := newRequestParser(req, nil)
+	parser, err := NewRequestParser(req, nil)
 	assert.NotNil(t, err)
 	assert.Nil(t, parser)
 }
 
 func TestNewRequestParserWithForm(t *testing.T) {
 	req := httptest.NewRequest("GET", "/val?a=b", nil)
-	parser, err := newRequestParser(req, nil)
+	parser, err := NewRequestParser(req, nil)
 	assert.Nil(t, err)
 	assert.NotNil(t, parser)
 }

+ 19 - 0
gateway/internal/timeout.go

@@ -0,0 +1,19 @@
+package internal
+
+import (
+	"net/http"
+	"time"
+)
+
+const grpcTimeoutHeader = "Grpc-Timeout"
+
+// GetTimeout returns the timeout from the header, if not set, returns the default timeout.
+func GetTimeout(header http.Header, defaultTimeout time.Duration) time.Duration {
+	if timeout := header.Get(grpcTimeoutHeader); len(timeout) > 0 {
+		if t, err := time.ParseDuration(timeout); err == nil {
+			return t
+		}
+	}
+
+	return defaultTimeout
+}

+ 22 - 0
gateway/internal/timeout_test.go

@@ -0,0 +1,22 @@
+package internal
+
+import (
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetTimeout(t *testing.T) {
+	req := httptest.NewRequest("GET", "/", nil)
+	req.Header.Set(grpcTimeoutHeader, "1s")
+	timeout := GetTimeout(req.Header, time.Second*5)
+	assert.Equal(t, time.Second, timeout)
+}
+
+func TestGetTimeoutDefault(t *testing.T) {
+	req := httptest.NewRequest("GET", "/", nil)
+	timeout := GetTimeout(req.Header, time.Second*5)
+	assert.Equal(t, time.Second*5, timeout)
+}

+ 2 - 2
gateway/readme.md

@@ -35,7 +35,7 @@ Upstreams:
     Mapping:
       - Method: get
         Path: /pingHello/:ping
-        Rpc: hello.Hello/Ping
+        RpcPath: hello.Hello/Ping
   - Grpc:
       Endpoints:
         - localhost:8081
@@ -43,7 +43,7 @@ Upstreams:
     Mapping:
       - Method: post
         Path: /pingWorld
-        Rpc: world.World/Ping
+        RpcPath: world.World/Ping
 ```
 
 ## Generate ProtoSet files

+ 21 - 3
gateway/server.go

@@ -2,6 +2,7 @@ package gateway
 
 import (
 	"context"
+	"fmt"
 	"net/http"
 	"strings"
 	"time"
@@ -11,6 +12,7 @@ import (
 	"github.com/jhump/protoreflect/grpcreflect"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/mr"
+	"github.com/zeromicro/go-zero/gateway/internal"
 	"github.com/zeromicro/go-zero/rest"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/zrpc"
@@ -58,8 +60,23 @@ func (s *Server) build() error {
 			return
 		}
 
+		methods, err := internal.GetMethods(source)
+		if err != nil {
+			cancel(err)
+			return
+		}
+
+		methodSet := make(map[string]struct{})
+		for _, m := range methods {
+			methodSet[m] = struct{}{}
+		}
 		resolver := grpcurl.AnyResolverFromDescriptorSource(source)
 		for _, m := range up.Mapping {
+			if _, ok := methodSet[m.RpcPath]; !ok {
+				cancel(fmt.Errorf("rpc method %s not found", m.RpcPath))
+				return
+			}
+
 			writer.Write(rest.Route{
 				Method:  strings.ToUpper(m.Method),
 				Path:    m.Path,
@@ -82,15 +99,16 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
 			Formatter: grpcurl.NewJSONFormatter(true,
 				grpcurl.AnyResolverFromDescriptorSource(source)),
 		}
-		parser, err := newRequestParser(r, resolver)
+		parser, err := internal.NewRequestParser(r, resolver)
 		if err != nil {
 			httpx.Error(w, err)
 			return
 		}
 
-		ctx, can := context.WithTimeout(r.Context(), s.timeout)
+		timeout := internal.GetTimeout(r.Header, s.timeout)
+		ctx, can := context.WithTimeout(r.Context(), timeout)
 		defer can()
-		if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.Rpc, buildHeaders(r.Header),
+		if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.RpcPath, internal.BuildHeaders(r.Header),
 			handler, parser.Next); err != nil {
 			httpx.Error(w, err)
 		}