|
- package generator
- import (
- _ "embed"
- "fmt"
- "path/filepath"
- "sort"
- "strings"
- "github.com/emicklei/proto"
- "github.com/wuntsong-org/go-zero-plus/core/collection"
- conf "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/config"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/rpc/parser"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/format"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/pathx"
- "github.com/wuntsong-org/go-zero-plus/tools/goctlwt/util/stringx"
- )
// callInterfaceFunctionTemplate is the in-source template text for one method
// signature of the generated client interface. getInterfaceFuncs hands it to
// pathx.LoadTemplate together with callInterfaceFunctionTemplateFile.
const callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context{{if .hasReq}}, in *{{.pbRequest}}{{end}}, opts ...grpc.CallOption) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error)`

// callFunctionTemplate is the in-source template text for one method
// implementation on the generated default client type. genFunction hands it to
// pathx.LoadTemplate together with callFunctionTemplateFile.
const callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}}, in *{{.pbRequest}}{{end}}, opts ...grpc.CallOption) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) {
client := {{if .isCallPkgSameToGrpcPkg}}{{else}}{{.package}}.{{end}}New{{.rpcServiceName}}Client(m.cli.Conn())
return client.{{.method}}(ctx{{if .hasReq}}, in{{end}}, opts...)
}
`
// callTemplateText holds the contents of call.tpl, embedded into the binary at
// build time. It is the template text passed to pathx.LoadTemplate when a call
// (rpc client) file is generated.
//
//go:embed call.tpl
var callTemplateText string
- // GenCall generates the rpc client code, which is the entry point for the rpc service call.
- // It is a layer of encapsulation for the rpc client and shields the details in the pb.
- func (g *Generator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config,
- c *ZRpcContext) error {
- if !c.Multiple {
- return g.genCallInCompatibility(ctx, proto, cfg)
- }
- return g.genCallGroup(ctx, proto, cfg)
- }
- func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
- dir := ctx.GetCall()
- head := util.GetHead(proto.Name)
- for _, service := range proto.Service {
- childPkg, err := dir.GetChildPackage(service.Name)
- if err != nil {
- return err
- }
- callFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name)
- if err != nil {
- return err
- }
- childDir := filepath.Base(childPkg)
- filename := filepath.Join(dir.Filename, childDir, fmt.Sprintf("%s.go", callFilename))
- isCallPkgSameToPbPkg := childDir == ctx.GetProtoGo().Filename
- isCallPkgSameToGrpcPkg := childDir == ctx.GetProtoGo().Filename
- serviceName := stringx.From(service.Name).ToCamel()
- alias := collection.NewSet()
- var hasSameNameBetweenMessageAndService bool
- for _, item := range proto.Message {
- msgName := getMessageName(*item.Message)
- if serviceName == msgName {
- hasSameNameBetweenMessageAndService = true
- }
- if !isCallPkgSameToPbPkg {
- alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
- fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
- }
- }
- if hasSameNameBetweenMessageAndService {
- serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
- }
- functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
- if err != nil {
- return err
- }
- iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
- if err != nil {
- return err
- }
- text, err := pathx.LoadTemplate(category, callTemplateFile, callTemplateText)
- if err != nil {
- return err
- }
- pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
- protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
- if isCallPkgSameToGrpcPkg {
- pbPackage = ""
- protoGoPackage = ""
- }
- aliasKeys := alias.KeysStr()
- sort.Strings(aliasKeys)
- if err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
- "name": callFilename,
- "alias": strings.Join(aliasKeys, pathx.NL),
- "head": head,
- "filePackage": childDir,
- "pbPackage": pbPackage,
- "protoGoPackage": protoGoPackage,
- "serviceName": serviceName,
- "functions": strings.Join(functions, pathx.NL),
- "interface": strings.Join(iFunctions, pathx.NL),
- }, filename, true); err != nil {
- return err
- }
- }
- return nil
- }
- func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
- cfg *conf.Config) error {
- dir := ctx.GetCall()
- service := proto.Service[0]
- head := util.GetHead(proto.Name)
- isCallPkgSameToPbPkg := ctx.GetCall().Filename == ctx.GetPb().Filename
- isCallPkgSameToGrpcPkg := ctx.GetCall().Filename == ctx.GetProtoGo().Filename
- callFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name)
- if err != nil {
- return err
- }
- serviceName := stringx.From(service.Name).ToCamel()
- alias := collection.NewSet()
- var hasSameNameBetweenMessageAndService bool
- for _, item := range proto.Message {
- msgName := getMessageName(*item.Message)
- if serviceName == msgName {
- hasSameNameBetweenMessageAndService = true
- }
- if !isCallPkgSameToPbPkg {
- alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
- fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
- }
- }
- if hasSameNameBetweenMessageAndService {
- serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
- }
- filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
- functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
- if err != nil {
- return err
- }
- iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
- if err != nil {
- return err
- }
- text, err := pathx.LoadTemplate(category, callTemplateFile, callTemplateText)
- if err != nil {
- return err
- }
- pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
- protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
- if isCallPkgSameToGrpcPkg {
- pbPackage = ""
- protoGoPackage = ""
- }
- aliasKeys := alias.KeysStr()
- sort.Strings(aliasKeys)
- return util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
- "name": callFilename,
- "alias": strings.Join(aliasKeys, pathx.NL),
- "head": head,
- "filePackage": dir.Base,
- "pbPackage": pbPackage,
- "protoGoPackage": protoGoPackage,
- "serviceName": serviceName,
- "functions": strings.Join(functions, pathx.NL),
- "interface": strings.Join(iFunctions, pathx.NL),
- }, filename, true)
- }
- func getMessageName(msg proto.Message) string {
- list := []string{msg.Name}
- for {
- parent := msg.Parent
- if parent == nil {
- break
- }
- parentMsg, ok := parent.(*proto.Message)
- if !ok {
- break
- }
- tmp := []string{parentMsg.Name}
- list = append(tmp, list...)
- msg = *parentMsg
- }
- return strings.Join(list, "_")
- }
- func (g *Generator) genFunction(goPackage string, serviceName string, service parser.Service,
- isCallPkgSameToGrpcPkg bool) ([]string, error) {
- functions := make([]string, 0)
- for _, rpc := range service.RPC {
- text, err := pathx.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
- if err != nil {
- return nil, err
- }
- comment := parser.GetComment(rpc.Doc())
- streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name),
- parser.CamelCase(rpc.Name), "Client")
- if isCallPkgSameToGrpcPkg {
- streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
- parser.CamelCase(rpc.Name), "Client")
- }
- buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]any{
- "serviceName": serviceName,
- "rpcServiceName": parser.CamelCase(service.Name),
- "method": parser.CamelCase(rpc.Name),
- "package": goPackage,
- "pbRequest": parser.CamelCase(rpc.RequestType),
- "pbResponse": parser.CamelCase(rpc.ReturnsType),
- "hasComment": len(comment) > 0,
- "comment": comment,
- "hasReq": !rpc.StreamsRequest,
- "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
- "streamBody": streamServer,
- "isCallPkgSameToGrpcPkg": isCallPkgSameToGrpcPkg,
- })
- if err != nil {
- return nil, err
- }
- functions = append(functions, buffer.String())
- }
- return functions, nil
- }
- func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
- isCallPkgSameToGrpcPkg bool) ([]string, error) {
- functions := make([]string, 0)
- for _, rpc := range service.RPC {
- text, err := pathx.LoadTemplate(category, callInterfaceFunctionTemplateFile,
- callInterfaceFunctionTemplate)
- if err != nil {
- return nil, err
- }
- comment := parser.GetComment(rpc.Doc())
- streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name),
- parser.CamelCase(rpc.Name), "Client")
- if isCallPkgSameToGrpcPkg {
- streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
- parser.CamelCase(rpc.Name), "Client")
- }
- buffer, err := util.With("interfaceFn").Parse(text).Execute(
- map[string]any{
- "hasComment": len(comment) > 0,
- "comment": comment,
- "method": parser.CamelCase(rpc.Name),
- "hasReq": !rpc.StreamsRequest,
- "pbRequest": parser.CamelCase(rpc.RequestType),
- "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
- "pbResponse": parser.CamelCase(rpc.ReturnsType),
- "streamBody": streamServer,
- })
- if err != nil {
- return nil, err
- }
- functions = append(functions, buffer.String())
- }
- return functions, nil
- }
|