123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- package generator
- import (
- _ "embed"
- "fmt"
- "path/filepath"
- "strings"
- "github.com/wuntsong-org/go-zero-plus/core/collection"
- conf "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/config"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/rpc/parser"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/format"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/pathx"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/stringx"
- )
- const functionTemplate = `
- {{if .hasComment}}{{.comment}}{{end}}
- func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{if .hasReq}} in {{.request}}{{end}}{{else}}{{if .hasReq}} in {{.request}},{{end}}stream {{.streamBody}}{{end}}) ({{if .notStream}}{{.response}},{{end}}error) {
- l := {{.logicPkg}}.New{{.logicName}}({{if .notStream}}ctx,{{else}}stream.Context(),{{end}}s.svcCtx)
- return l.{{.method}}({{if .hasReq}}in{{if .stream}} ,stream{{end}}{{else}}{{if .stream}}stream{{end}}{{end}})
- }
- `
- //go:embed server.tpl
- var serverTemplate string
- // GenServer generates rpc server file, which is an implementation of rpc server
- func (g *Generator) GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config,
- c *ZRpcContext) error {
- if !c.Multiple {
- return g.genServerInCompatibility(ctx, proto, cfg, c)
- }
- return g.genServerGroup(ctx, proto, cfg)
- }
- func (g *Generator) genServerGroup(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
- dir := ctx.GetServer()
- for _, service := range proto.Service {
- var (
- serverFile string
- logicImport string
- )
- serverFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name+"_server")
- if err != nil {
- return err
- }
- serverChildPkg, err := dir.GetChildPackage(service.Name)
- if err != nil {
- return err
- }
- logicChildPkg, err := ctx.GetLogic().GetChildPackage(service.Name)
- if err != nil {
- return err
- }
- serverDir := filepath.Base(serverChildPkg)
- logicImport = fmt.Sprintf(`"%v"`, logicChildPkg)
- serverFile = filepath.Join(dir.Filename, serverDir, serverFilename+".go")
- svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
- pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
- imports := collection.NewSet()
- imports.AddStr(logicImport, svcImport, pbImport)
- head := util.GetHead(proto.Name)
- funcList, err := g.genFunctions(proto.PbPackage, service, true)
- if err != nil {
- return err
- }
- text, err := pathx.LoadTemplate(category, serverTemplateFile, serverTemplate)
- if err != nil {
- return err
- }
- notStream := false
- for _, rpc := range service.RPC {
- if !rpc.StreamsRequest && !rpc.StreamsReturns {
- notStream = true
- break
- }
- }
- if err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]any{
- "head": head,
- "unimplementedServer": fmt.Sprintf("%s.Unimplemented%sServer", proto.PbPackage,
- stringx.From(service.Name).ToCamel()),
- "server": stringx.From(service.Name).ToCamel(),
- "imports": strings.Join(imports.KeysStr(), pathx.NL),
- "funcs": strings.Join(funcList, pathx.NL),
- "notStream": notStream,
- }, serverFile, true); err != nil {
- return err
- }
- }
- return nil
- }
- func (g *Generator) genServerInCompatibility(ctx DirContext, proto parser.Proto,
- cfg *conf.Config, c *ZRpcContext) error {
- dir := ctx.GetServer()
- logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
- svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
- pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
- imports := collection.NewSet()
- imports.AddStr(logicImport, svcImport, pbImport)
- head := util.GetHead(proto.Name)
- service := proto.Service[0]
- serverFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name+"_server")
- if err != nil {
- return err
- }
- serverFile := filepath.Join(dir.Filename, serverFilename+".go")
- funcList, err := g.genFunctions(proto.PbPackage, service, false)
- if err != nil {
- return err
- }
- text, err := pathx.LoadTemplate(category, serverTemplateFile, serverTemplate)
- if err != nil {
- return err
- }
- notStream := false
- for _, rpc := range service.RPC {
- if !rpc.StreamsRequest && !rpc.StreamsReturns {
- notStream = true
- break
- }
- }
- return util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]any{
- "head": head,
- "unimplementedServer": fmt.Sprintf("%s.Unimplemented%sServer", proto.PbPackage,
- stringx.From(service.Name).ToCamel()),
- "server": stringx.From(service.Name).ToCamel(),
- "imports": strings.Join(imports.KeysStr(), pathx.NL),
- "funcs": strings.Join(funcList, pathx.NL),
- "notStream": notStream,
- }, serverFile, true)
- }
- func (g *Generator) genFunctions(goPackage string, service parser.Service, multiple bool) ([]string, error) {
- var (
- functionList []string
- logicPkg string
- )
- for _, rpc := range service.RPC {
- text, err := pathx.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
- if err != nil {
- return nil, err
- }
- var logicName string
- if !multiple {
- logicPkg = "logic"
- logicName = fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel())
- } else {
- nameJoin := fmt.Sprintf("%s_logic", service.Name)
- logicPkg = strings.ToLower(stringx.From(nameJoin).ToCamel())
- logicName = fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel())
- }
- comment := parser.GetComment(rpc.Doc())
- streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name),
- parser.CamelCase(rpc.Name), "Server")
- buffer, err := util.With("func").Parse(text).Execute(map[string]any{
- "server": stringx.From(service.Name).ToCamel(),
- "logicName": logicName,
- "method": parser.CamelCase(rpc.Name),
- "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
- "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
- "hasComment": len(comment) > 0,
- "comment": comment,
- "hasReq": !rpc.StreamsRequest,
- "stream": rpc.StreamsRequest || rpc.StreamsReturns,
- "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
- "streamBody": streamServer,
- "logicPkg": logicPkg,
- })
- if err != nil {
- return nil, err
- }
- functionList = append(functionList, buffer.String())
- }
- return functionList, nil
- }
|