Selaa lähdekoodia

refine rpc generator

kevin 4 vuotta sitten
vanhempi
sitoutus
f411178a4f

+ 3 - 6
tools/goctl/api/parser/util.go

@@ -9,17 +9,14 @@ import (
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 )
 
-const (
-	// struct匹配
-	typeRegex = `(?m)(?m)(^ *type\s+[a-zA-Z][a-zA-Z0-9_-]+\s+(((struct)\s*?\{[\w\W]*?[^\{]\})|([a-zA-Z][a-zA-Z0-9_-]+)))|(^ *type\s*?\([\w\W]+\}\s*\))`
-)
+// struct匹配
+const typeRegex = `(?m)(?m)(^ *type\s+[a-zA-Z][a-zA-Z0-9_-]+\s+(((struct)\s*?\{[\w\W]*?[^\{]\})|([a-zA-Z][a-zA-Z0-9_-]+)))|(^ *type\s*?\([\w\W]+\}\s*\))`
 
 var (
 	emptyStrcut = errors.New("struct body not found")
+	emptyType   spec.Type
 )
 
-var emptyType spec.Type
-
 func GetType(api *spec.ApiSpec, t string) spec.Type {
 	for _, tp := range api.Types {
 		if tp.Name == t {

+ 1 - 1
tools/goctl/model/sql/gen/gen.go

@@ -171,7 +171,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
 		"types":   typesCode,
 		"new":     newCode,
 		"insert":  insertCode,
-		"find":    strings.Join(findCode, "\r\n"),
+		"find":    strings.Join(findCode, "\n"),
 		"update":  updateCode,
 		"delete":  deleteCode,
 	})

+ 1 - 1
tools/goctl/model/sql/gen/vars.go

@@ -20,7 +20,7 @@ func genVars(table Table, withCache bool) (string, error) {
 		Execute(map[string]interface{}{
 			"lowerStartCamelObject": stringx.From(camel).UnTitle(),
 			"upperStartCamelObject": camel,
-			"cacheKeys":             strings.Join(keys, "\r\n"),
+			"cacheKeys":             strings.Join(keys, "\n"),
 			"autoIncrement":         table.PrimaryKey.AutoIncrement,
 			"originalPrimaryKey":    table.PrimaryKey.Name.Source(),
 			"withCache":             withCache,

+ 0 - 3
tools/goctl/rpc/ctx/project.go

@@ -1,7 +1,6 @@
 package ctx
 
 import (
-	"errors"
 	"fmt"
 	"io/ioutil"
 	"os"
@@ -16,8 +15,6 @@ import (
 	"github.com/tal-tech/go-zero/tools/goctl/util/console"
 )
 
-var errProtobufNotFound = errors.New("github.com/golang/protobuf is not found,please ensure you has already [go get github.com/golang/protobuf]")
-
 const (
 	constGo          = "go"
 	constProtoC      = "protoc"

+ 1 - 1
tools/goctl/rpc/gen/gen.go

@@ -70,7 +70,7 @@ func (g *defaultRpcGenerator) Generate() (err error) {
 		return
 	}
 
-	err = g.genRemoteHandler()
+	err = g.genHandler()
 	if err != nil {
 		return
 	}

+ 37 - 18
tools/goctl/rpc/gen/genhandler.go

@@ -9,15 +9,29 @@ import (
 )
 
 const (
-	remoteTemplate = `{{.head}}
+	handlerTemplate = `{{.head}}
 
 package handler
 
-import {{.imports}}
+import (
+	"context"
+
+	{{.imports}}
+)
 
 type {{.types}}
 
-{{.newFuncs}}
+func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
+	return &{{.server}}Server{
+		svcCtx: svcCtx,
+	}
+}
+
+{{if .hasComment}}{{.comment}}{{end}}
+func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
+	l := logic.New{{.logicName}}(ctx,s.svcCtx)
+	return l.{{.method}}(in)
+}
 `
 	functionTemplate = `{{.head}}
 
@@ -29,8 +43,6 @@ import (
 	{{.imports}}
 )
 
-type {{.server}}Server struct{}
-
 {{if .hasComment}}{{.comment}}{{end}}
 func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
 	l := logic.New{{.logicName}}(ctx,s.svcCtx)
@@ -47,29 +59,35 @@ func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{
 }`
 )
 
-func (g *defaultRpcGenerator) genRemoteHandler() error {
+func (g *defaultRpcGenerator) genHandler() error {
 	handlerPath := g.dirM[dirHandler]
-	serverGo := fmt.Sprintf("%vhandler.go", g.Ctx.ServiceName.Lower())
-	fileName := filepath.Join(handlerPath, serverGo)
+	filename := fmt.Sprintf("%vhandler.go", g.Ctx.ServiceName.Lower())
+	handlerFile := filepath.Join(handlerPath, filename)
 	file := g.ast
+	pkg := file.Package
+	pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
+	logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
 	svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
+	imports := []string{
+		pbImport,
+		logicImport,
+		svcImport,
+	}
 	types := make([]string, 0)
 	newFuncs := make([]string, 0)
 	head := util.GetHead(g.Ctx.ProtoSource)
 	for _, service := range file.Service {
 		types = append(types, fmt.Sprintf(typeFmt, service.Name.Title()))
-		newFuncs = append(newFuncs, fmt.Sprintf(newFuncFmt, service.Name.Title(), service.Name.Title(), service.Name.Title()))
+		newFuncs = append(newFuncs, fmt.Sprintf(newFuncFmt, service.Name.Title(),
+			service.Name.Title(), service.Name.Title()))
 	}
-	err := util.With("server").GoFmt(true).Parse(remoteTemplate).SaveTo(map[string]interface{}{
+
+	return util.With("server").GoFmt(true).Parse(handlerTemplate).SaveTo(map[string]interface{}{
 		"head":     head,
 		"types":    strings.Join(types, "\n"),
 		"newFuncs": strings.Join(newFuncs, "\n"),
-		"imports":  svcImport,
-	}, fileName, true)
-	if err != nil {
-		return err
-	}
-	return g.genFunctions()
+		"imports":  strings.Join(imports, "\n\t"),
+	}, handlerFile, true)
 }
 
 func (g *defaultRpcGenerator) genFunctions() error {
@@ -89,19 +107,20 @@ func (g *defaultRpcGenerator) genFunctions() error {
 			err := util.With("func").GoFmt(true).Parse(functionTemplate).SaveTo(map[string]interface{}{
 				"head":       head,
 				"server":     service.Name.Title(),
-				"imports":    strings.Join(handlerImports, "\r\n"),
+				"imports":    strings.Join(handlerImports, "\n"),
 				"logicName":  fmt.Sprintf("%sLogic", method.Name.Title()),
 				"method":     method.Name.Title(),
 				"package":    pkg,
 				"request":    method.InType,
 				"response":   method.OutType,
 				"hasComment": len(method.Document),
-				"comment":    strings.Join(method.Document, "\r\n"),
+				"comment":    strings.Join(method.Document, "\n"),
 			}, filename, true)
 			if err != nil {
 				return err
 			}
 		}
 	}
+
 	return nil
 }

+ 2 - 2
tools/goctl/rpc/gen/genlogic.go

@@ -64,7 +64,7 @@ func (g *defaultRpcGenerator) genLogic() error {
 			err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
 				"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
 				"functions": functions,
-				"imports":   strings.Join(imports.KeysStr(), "\r\n"),
+				"imports":   strings.Join(imports.KeysStr(), "\n"),
 			}, filename, false)
 			if err != nil {
 				return err
@@ -83,7 +83,7 @@ func genLogicFunction(packageName string, method *parser.Func) (string, error) {
 		"request":    method.InType,
 		"response":   method.OutType,
 		"hasComment": len(method.Document) > 0,
-		"comment":    strings.Join(method.Document, "\r\n"),
+		"comment":    strings.Join(method.Document, "\n"),
 	})
 	if err != nil {
 		return "", err

+ 1 - 1
tools/goctl/rpc/gen/genmain.go

@@ -67,7 +67,7 @@ func (g *defaultRpcGenerator) genMain() error {
 		"serviceName": g.Ctx.ServiceName.Lower(),
 		"srv":         srv,
 		"registers":   registers,
-		"imports":     strings.Join(imports, "\r\n"),
+		"imports":     strings.Join(imports, "\n"),
 	}, fileName, true)
 }