Преглед на файлове

feat: support google.api.http in gateway (#2161)

Kevin Wan преди 2 години
родител
ревизия
0dd2768d09
променени са 4 файла, в които са добавени 94 реда и са изтрити 10 реда
  1. 3 1
      gateway/config.go
  2. 71 3
      gateway/internal/descriptorsource.go
  3. 5 1
      gateway/internal/descriptorsource_test.go
  4. 15 5
      gateway/server.go

+ 3 - 1
gateway/config.go

@@ -31,6 +31,8 @@ type (
 		Grpc zrpc.RpcClientConf
 		// ProtoSet is the file of proto set, like hello.pb
 		ProtoSet string `json:",optional"`
-		Mapping  []mapping
+		// Mapping is the mapping between gateway routes and upstream rpc methods.
+		// Keep it blank if annotations are added in rpc methods.
+		Mapping []mapping `json:",optional"`
 	}
 )

+ 71 - 3
gateway/internal/descriptorsource.go

@@ -2,19 +2,29 @@ package internal
 
 import (
 	"fmt"
+	"net/http"
+	"strings"
 
 	"github.com/fullstorydev/grpcurl"
 	"github.com/jhump/protoreflect/desc"
+	"google.golang.org/genproto/googleapis/api/annotations"
+	"google.golang.org/protobuf/proto"
 )
 
+type Method struct {
+	HttpMethod string
+	HttpPath   string
+	RpcPath    string
+}
+
 // GetMethods returns all methods of the given grpcurl.DescriptorSource.
-func GetMethods(source grpcurl.DescriptorSource) ([]string, error) {
+func GetMethods(source grpcurl.DescriptorSource) ([]Method, error) {
 	svcs, err := source.ListServices()
 	if err != nil {
 		return nil, err
 	}
 
-	var methods []string
+	var methods []Method
 	for _, svc := range svcs {
 		d, err := source.FindSymbol(svc)
 		if err != nil {
@@ -25,10 +35,68 @@ func GetMethods(source grpcurl.DescriptorSource) ([]string, error) {
 		case *desc.ServiceDescriptor:
 			svcMethods := val.GetMethods()
 			for _, method := range svcMethods {
-				methods = append(methods, fmt.Sprintf("%s/%s", svc, method.GetName()))
+				rpcPath := fmt.Sprintf("%s/%s", svc, method.GetName())
+				ext := proto.GetExtension(method.GetMethodOptions(), annotations.E_Http)
+				if ext == nil {
+					methods = append(methods, Method{
+						RpcPath: rpcPath,
+					})
+					continue
+				}
+
+				httpExt, ok := ext.(*annotations.HttpRule)
+				if !ok {
+					methods = append(methods, Method{
+						RpcPath: rpcPath,
+					})
+					continue
+				}
+
+				switch rule := httpExt.GetPattern().(type) {
+				case *annotations.HttpRule_Get:
+					methods = append(methods, Method{
+						HttpMethod: http.MethodGet,
+						HttpPath:   adjustHttpPath(rule.Get),
+						RpcPath:    rpcPath,
+					})
+				case *annotations.HttpRule_Post:
+					methods = append(methods, Method{
+						HttpMethod: http.MethodPost,
+						HttpPath:   adjustHttpPath(rule.Post),
+						RpcPath:    rpcPath,
+					})
+				case *annotations.HttpRule_Put:
+					methods = append(methods, Method{
+						HttpMethod: http.MethodPut,
+						HttpPath:   adjustHttpPath(rule.Put),
+						RpcPath:    rpcPath,
+					})
+				case *annotations.HttpRule_Delete:
+					methods = append(methods, Method{
+						HttpMethod: http.MethodDelete,
+						HttpPath:   adjustHttpPath(rule.Delete),
+						RpcPath:    rpcPath,
+					})
+				case *annotations.HttpRule_Patch:
+					methods = append(methods, Method{
+						HttpMethod: http.MethodPatch,
+						HttpPath:   adjustHttpPath(rule.Patch),
+						RpcPath:    rpcPath,
+					})
+				default:
+					methods = append(methods, Method{
+						RpcPath: rpcPath,
+					})
+				}
 			}
 		}
 	}
 
 	return methods, nil
 }
+
+func adjustHttpPath(path string) string {
+	path = strings.ReplaceAll(path, "{", ":")
+	path = strings.ReplaceAll(path, "}", "")
+	return path
+}

+ 5 - 1
gateway/internal/descriptorsource_test.go

@@ -25,5 +25,9 @@ func TestGetMethods(t *testing.T) {
 	assert.Nil(t, err)
 	methods, err := GetMethods(source)
 	assert.Nil(t, err)
-	assert.EqualValues(t, []string{"hello.Hello/Ping"}, methods)
+	assert.EqualValues(t, []Method{
+		{
+			RpcPath: "hello.Hello/Ping",
+		},
+	}, methods)
 }

+ 15 - 5
gateway/server.go

@@ -66,11 +66,21 @@ func (s *Server) build() error {
 			return
 		}
 
+		resolver := grpcurl.AnyResolverFromDescriptorSource(source)
+		for _, m := range methods {
+			if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 {
+				writer.Write(rest.Route{
+					Method:  m.HttpMethod,
+					Path:    m.HttpPath,
+					Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
+				})
+			}
+		}
+
 		methodSet := make(map[string]struct{})
 		for _, m := range methods {
-			methodSet[m] = struct{}{}
+			methodSet[m.RpcPath] = struct{}{}
 		}
-		resolver := grpcurl.AnyResolverFromDescriptorSource(source)
 		for _, m := range up.Mapping {
 			if _, ok := methodSet[m.RpcPath]; !ok {
 				cancel(fmt.Errorf("rpc method %s not found", m.RpcPath))
@@ -80,7 +90,7 @@ func (s *Server) build() error {
 			writer.Write(rest.Route{
 				Method:  strings.ToUpper(m.Method),
 				Path:    m.Path,
-				Handler: s.buildHandler(source, resolver, cli, m),
+				Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
 			})
 		}
 	}, func(pipe <-chan interface{}, cancel func(error)) {
@@ -92,7 +102,7 @@ func (s *Server) build() error {
 }
 
 func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
-	cli zrpc.Client, m mapping) func(http.ResponseWriter, *http.Request) {
+	cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) {
 	return func(w http.ResponseWriter, r *http.Request) {
 		handler := &grpcurl.DefaultEventHandler{
 			Out: w,
@@ -110,7 +120,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
 		defer can()
 
 		w.Header().Set(httpx.ContentType, httpx.JsonContentType)
-		if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.RpcPath, internal.BuildHeaders(r.Header),
+		if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, internal.BuildHeaders(r.Header),
 			handler, parser.Next); err != nil {
 			httpx.Error(w, err)
 		}