Pārlūkot izejas kodu

refactor middleware generator (#159)

* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* revert

* refactor and rename folder to group

* remove no need

* add anonymous annotation

* optimized

* rename

* rename

* update test

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* simple logic

* optimized

* optimized generator formatted code

* optimized generator formatted code

* add more test

* refactor middleware generator

* revert test

* revert test

* revert test

* revert test

* revert test

Co-authored-by: kingxt <dream4kingxt@163.com>
kingxt 4 gadi atpakaļ
vecāks
revīzija
1c9e81aa28

+ 59 - 0
tools/goctl/api/gogen/genmiddleware.go

@@ -0,0 +1,59 @@
+package gogen
+
+import (
+	"bytes"
+	"strings"
+	"text/template"
+
+	"github.com/tal-tech/go-zero/tools/goctl/api/util"
+)
+
+var middlewareImplementCode = `
+package middleware
+
+import "net/http"
+
+type {{.name}} struct {
+}
+
+func New{{.name}}() *{{.name}} {	
+	return &{{.name}}{}
+}
+
+func (m *{{.name}})Handle(next http.HandlerFunc) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		// TODO generate middleware implement function, delete after code implementation
+
+		// Passthrough to next handler if need 
+		next(w, r)
+	}	
+}
+`
+
+func genMiddleware(dir string, middlewares []string) error {
+	for _, item := range middlewares {
+		filename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "middleware" + ".go"
+		fp, created, err := util.MaybeCreateFile(dir, middlewareDir, filename)
+		if err != nil {
+			return err
+		}
+		if !created {
+			return nil
+		}
+		defer fp.Close()
+
+		name := strings.TrimSuffix(item, "Middleware") + "Middleware"
+		t := template.Must(template.New("contextTemplate").Parse(middlewareImplementCode))
+		buffer := new(bytes.Buffer)
+		err = t.Execute(buffer, map[string]string{
+			"name": strings.Title(name),
+		})
+		if err != nil {
+			return nil
+		}
+		formatCode := formatCode(buffer.String())
+		_, err = fp.WriteString(formatCode)
+		return err
+	}
+	return nil
+}

+ 11 - 14
tools/goctl/api/gogen/gensvc.go

@@ -3,6 +3,7 @@ package gogen
 import (
 	"bytes"
 	"fmt"
+	"strings"
 	"text/template"
 
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
@@ -31,14 +32,6 @@ func NewServiceContext(c {{.config}}) *ServiceContext {
 	}
 }
 
-{{.middlewareImplement}}
-`
-	middlewareImplementCode = `func %s(next http.HandlerFunc) http.HandlerFunc {
-	return func(w http.ResponseWriter, r *http.Request) {
-		// TODO generate middleware implement function, delete after code implementation 
-	}
-}
-
 `
 )
 
@@ -70,16 +63,21 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
 
 	var middlewareStr string
 	var middlewareAssignment string
-	var middlewareImplement string
-	for _, item := range getMiddleware(api) {
+	var middlewares = getMiddleware(api)
+	err = genMiddleware(dir, middlewares)
+	if err != nil {
+		return err
+	}
+
+	for _, item := range middlewares {
 		middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
-		middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, item)
-		middlewareImplement += fmt.Sprintf(middlewareImplementCode, item)
+		name := strings.TrimSuffix(item, "Middleware") + "Middleware"
+		middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle"))
 	}
 
 	var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
 	if len(middlewareStr) > 0 {
-		configImport += "\n\t\"net/http\""
+		configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
 		configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl)
 	}
 
@@ -90,7 +88,6 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
 		"config":               "config.Config",
 		"middleware":           middlewareStr,
 		"middlewareAssignment": middlewareAssignment,
-		"middlewareImplement":  middlewareImplement,
 	})
 	if err != nil {
 		return nil

+ 1 - 0
tools/goctl/api/gogen/vars.go

@@ -7,6 +7,7 @@ const (
 	contextDir    = interval + "svc"
 	handlerDir    = interval + "handler"
 	logicDir      = interval + "logic"
+	middlewareDir = interval + "middleware"
 	typesDir      = interval + typesPacket
 	groupProperty = "group"
 )