浏览代码

api add middleware support (#140)

* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* revert

* refactor and rename folder to group

* remove no need

* add anonymous annotation

* optimized

* rename

* rename

* update test

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* simple logic

* should reverse middlewares

* optimized

* optimized

* rename

Co-authored-by: kingxt <dream4kingxt@163.com>
kingxt 4 年之前
父节点
当前提交
aa3c391919

+ 7 - 0
rest/server.go

@@ -103,6 +103,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
 	}
 	}
 }
 }
 
 
+func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
+	for i := len(ms) - 1; i >= 0; i-- {
+		rs = WithMiddleware(ms[i], rs...)
+	}
+	return rs
+}
+
 func WithMiddleware(middleware Middleware, rs ...Route) []Route {
 func WithMiddleware(middleware Middleware, rs ...Route) []Route {
 	routes := make([]Route, len(rs))
 	routes := make([]Route, len(rs))
 
 

+ 72 - 0
rest/server_test.go

@@ -68,3 +68,75 @@ func TestWithMiddleware(t *testing.T) {
 		"wan":   "2020",
 		"wan":   "2020",
 	}, m)
 	}, m)
 }
 }
+
+func TestMultiMiddleware(t *testing.T) {
+	m := make(map[string]string)
+	router := router.NewPatRouter()
+	handler := func(w http.ResponseWriter, r *http.Request) {
+		var v struct {
+			Nickname string `form:"nickname"`
+			Zipcode  int64  `form:"zipcode"`
+		}
+
+		err := httpx.Parse(r, &v)
+		assert.Nil(t, err)
+		_, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
+		assert.Nil(t, err)
+	}
+	rs := WithMiddlewares([]Middleware{
+		func(next http.HandlerFunc) http.HandlerFunc {
+			return func(w http.ResponseWriter, r *http.Request) {
+				var v struct {
+					Name string `path:"name"`
+					Year string `path:"year"`
+				}
+				assert.Nil(t, httpx.ParsePath(r, &v))
+				m[v.Name] = v.Year
+				next.ServeHTTP(w, r)
+			}
+		},
+		func(next http.HandlerFunc) http.HandlerFunc {
+			return func(w http.ResponseWriter, r *http.Request) {
+				var v struct {
+					Name    string `form:"nickname"`
+					Zipcode string `form:"zipcode"`
+				}
+				assert.Nil(t, httpx.ParseForm(r, &v))
+				assert.NotEmpty(t, m)
+				m[v.Name] = v.Zipcode + v.Zipcode
+				next.ServeHTTP(w, r)
+			}
+		},
+	}, Route{
+		Method:  http.MethodGet,
+		Path:    "/first/:name/:year",
+		Handler: handler,
+	}, Route{
+		Method:  http.MethodGet,
+		Path:    "/second/:name/:year",
+		Handler: handler,
+	})
+
+	urls := []string{
+		"http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
+		"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
+	}
+	for _, route := range rs {
+		assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
+	}
+	for _, url := range urls {
+		r, err := http.NewRequest(http.MethodGet, url, nil)
+		assert.Nil(t, err)
+
+		rr := httptest.NewRecorder()
+		router.ServeHTTP(rr, r)
+
+		assert.Equal(t, "whatever:200000200000", rr.Body.String())
+	}
+
+	assert.EqualValues(t, map[string]string{
+		"kevin":    "2017",
+		"wan":      "2020",
+		"whatever": "200000200000",
+	}, m)
+}

+ 23 - 3
tools/goctl/api/gogen/genroutes.go

@@ -31,9 +31,9 @@ func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) {
 }
 }
 `
 `
 	routesAdditionTemplate = `
 	routesAdditionTemplate = `
-	engine.AddRoutes([]rest.Route{
+	engine.AddRoutes(
 		{{.routes}}
 		{{.routes}}
-	}{{.jwt}}{{.signature}})
+	{{.jwt}}{{.signature}})
 `
 `
 )
 )
 
 
@@ -52,6 +52,7 @@ type (
 		jwtEnabled       bool
 		jwtEnabled       bool
 		signatureEnabled bool
 		signatureEnabled bool
 		authName         string
 		authName         string
+		middleware       []string
 	}
 	}
 	route struct {
 	route struct {
 		method  string
 		method  string
@@ -87,8 +88,22 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
 		if g.signatureEnabled {
 		if g.signatureEnabled {
 			signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName)
 			signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName)
 		}
 		}
+
+		var routes string
+		if len(g.middleware) > 0 {
+			var params = g.middleware
+			for i := range params {
+				params[i] = "serverCtx." + params[i]
+			}
+			var middlewareStr = strings.Join(params, ", ")
+			routes = fmt.Sprintf("rest.WithMultiMiddleware([]rest.Middleware{ %s }, []rest.Route{\n %s \n}),",
+				middlewareStr, strings.TrimSpace(gbuilder.String()))
+		} else {
+			routes = fmt.Sprintf("[]rest.Route{\n %s \n},", strings.TrimSpace(gbuilder.String()))
+		}
+
 		if err := gt.Execute(&builder, map[string]string{
 		if err := gt.Execute(&builder, map[string]string{
-			"routes":    strings.TrimSpace(gbuilder.String()),
+			"routes":    routes,
 			"jwt":       jwt,
 			"jwt":       jwt,
 			"signature": signature,
 			"signature": signature,
 		}); err != nil {
 		}); err != nil {
@@ -185,6 +200,11 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
 			groupedRoutes.authName = value
 			groupedRoutes.authName = value
 			groupedRoutes.jwtEnabled = true
 			groupedRoutes.jwtEnabled = true
 		}
 		}
+		if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
+			for _, item := range strings.Split(value, ",") {
+				groupedRoutes.middleware = append(groupedRoutes.middleware, item)
+			}
+		}
 		routes = append(routes, groupedRoutes)
 		routes = append(routes, groupedRoutes)
 	}
 	}
 
 

+ 15 - 1
tools/goctl/api/gogen/gensvc.go

@@ -9,16 +9,20 @@ import (
 	"github.com/tal-tech/go-zero/tools/goctl/api/util"
 	"github.com/tal-tech/go-zero/tools/goctl/api/util"
 	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
 	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/vars"
 )
 )
 
 
 const (
 const (
 	contextFilename = "servicecontext.go"
 	contextFilename = "servicecontext.go"
 	contextTemplate = `package svc
 	contextTemplate = `package svc
 
 
-import {{.configImport}}
+import (
+	{{.configImport}}
+)
 
 
 type ServiceContext struct {
 type ServiceContext struct {
 	Config {{.config}}
 	Config {{.config}}
+	{{.middleware}}
 }
 }
 
 
 func NewServiceContext(c {{.config}}) *ServiceContext {
 func NewServiceContext(c {{.config}}) *ServiceContext {
@@ -53,12 +57,22 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
 		return err
 		return err
 	}
 	}
 
 
+	var middlewareStr string
+	for _, item := range getMiddleware(api) {
+		middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
+	}
+
 	var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
 	var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
+	if len(middlewareStr) > 0 {
+		configImport += fmt.Sprintf("\n\"%s/rest\"", vars.ProjectOpenSourceUrl)
+	}
+
 	t := template.Must(template.New("contextTemplate").Parse(text))
 	t := template.Must(template.New("contextTemplate").Parse(text))
 	buffer := new(bytes.Buffer)
 	buffer := new(bytes.Buffer)
 	err = t.Execute(buffer, map[string]string{
 	err = t.Execute(buffer, map[string]string{
 		"configImport": configImport,
 		"configImport": configImport,
 		"config":       "config.Config",
 		"config":       "config.Config",
+		"middleware":   middlewareStr,
 	})
 	})
 	if err != nil {
 	if err != nil {
 		return nil
 		return nil

+ 12 - 0
tools/goctl/api/gogen/util.go

@@ -66,6 +66,18 @@ func getAuths(api *spec.ApiSpec) []string {
 	return authNames.KeysStr()
 	return authNames.KeysStr()
 }
 }
 
 
+func getMiddleware(api *spec.ApiSpec) []string {
+	result := collection.NewSet()
+	for _, g := range api.Service.Groups {
+		if value, ok := util.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
+			for _, item := range strings.Split(value, ",") {
+				result.Add(strings.TrimSpace(item))
+			}
+		}
+	}
+	return result.KeysStr()
+}
+
 func formatCode(code string) string {
 func formatCode(code string) string {
 	ret, err := goformat.Source([]byte(code))
 	ret, err := goformat.Source([]byte(code))
 	if err != nil {
 	if err != nil {

+ 31 - 0
tools/goctl/api/parser/parser_test.go

@@ -119,6 +119,24 @@ service A-api {
 }
 }
 `
 `
 
 
+const apiHasMiddleware = `
+type Request struct {
+  Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
+}
+
+type Response struct {
+  Message string ` + "`" + `json:"message"` + "`" + `
+}
+
+@server(
+	middleware: TokenValidate
+)
+service A-api {
+  @handler GreetHandler
+  get /greet/from/:name(Request) returns (Response)
+}
+`
+
 func TestParser(t *testing.T) {
 func TestParser(t *testing.T) {
 	filename := "greet.api"
 	filename := "greet.api"
 	err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
 	err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
@@ -198,3 +216,16 @@ func TestAnonymousAnnotation(t *testing.T) {
 	assert.Equal(t, len(api.Service.Routes), 1)
 	assert.Equal(t, len(api.Service.Routes), 1)
 	assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler")
 	assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler")
 }
 }
+
+func TestApiHasMiddleware(t *testing.T) {
+	filename := "greet.api"
+	err := ioutil.WriteFile(filename, []byte(apiHasMiddleware), os.ModePerm)
+	assert.Nil(t, err)
+	defer os.Remove(filename)
+
+	parser, err := NewParser(filename)
+	assert.Nil(t, err)
+
+	_, err = parser.Parse()
+	assert.Nil(t, err)
+}