kevin 4 лет назад
Родитель
Сommit
f904710811

+ 7 - 1
tools/goctl/api/gogen/genconfig.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 	"github.com/tal-tech/go-zero/tools/goctl/api/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/vars"
 )
 
@@ -47,7 +48,12 @@ func genConfig(dir string, api *spec.ApiSpec) error {
 	}
 
 	var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)
-	t := template.Must(template.New("configTemplate").Parse(configTemplate))
+	text, err := templatex.LoadTemplate(category, configTemplateFile, configTemplate)
+	if err != nil {
+		return err
+	}
+
+	t := template.Must(template.New("configTemplate").Parse(text))
 	buffer := new(bytes.Buffer)
 	err = t.Execute(buffer, map[string]string{
 		"authImport": authImportStr,

+ 7 - 1
tools/goctl/api/gogen/genetc.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 	"github.com/tal-tech/go-zero/tools/goctl/api/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 )
 
 const (
@@ -39,7 +40,12 @@ func genEtc(dir string, api *spec.ApiSpec) error {
 		port = strconv.Itoa(defaultPort)
 	}
 
-	t := template.Must(template.New("etcTemplate").Parse(etcTemplate))
+	text, err := templatex.LoadTemplate(category, etcTemplateFile, etcTemplate)
+	if err != nil {
+		return err
+	}
+
+	t := template.Must(template.New("etcTemplate").Parse(text))
 	buffer := new(bytes.Buffer)
 	err = t.Execute(buffer, map[string]string{
 		"serviceName": service.Name,

+ 43 - 80
tools/goctl/api/gogen/genhandlers.go

@@ -9,115 +9,76 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 	apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/vars"
 )
 
-const (
-	handlerTemplate = `package handler
+const handlerTemplate = `package handler
 
 import (
 	"net/http"
 
-	{{.importPackages}}
+	{{.ImportPackages}}
 )
 
-func {{.handlerName}}(ctx *svc.ServiceContext) http.HandlerFunc {
+func {{.HandlerName}}(ctx *svc.ServiceContext) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
-		{{.handlerBody}}
-	}
-}
-`
-	handlerBodyTemplate = `{{.parseRequest}}
-		{{.processBody}}
-`
-	parseRequestTemplate = `var req {{.requestType}}
+		var req types.{{.RequestType}}
 		if err := httpx.Parse(r, &req); err != nil {
 			httpx.Error(w, err)
 			return
 		}
-`
-	hasRespTemplate = `
-		l := logic.{{.logic}}(r.Context(), ctx)
-		{{.logicResponse}} l.{{.callee}}({{.req}})
+
+		l := logic.New{{.LogicType}}(r.Context(), ctx)
+		{{if .HasResp}}resp, {{end}}err := l.{{.Call}}(req)
 		if err != nil {
 			httpx.Error(w, err)
 		} else {
-			{{.respWriter}}
+			{{if .HasResp}}httpx.OkJson(w, resp){{else}}httpx.Ok(w){{end}}
 		}
-	`
-)
+	}
+}
+`
+
+type Handler struct {
+	ImportPackages string
+	HandlerName    string
+	RequestType    string
+	LogicType      string
+	Call           string
+	HasResp        bool
+}
 
 func genHandler(dir string, group spec.Group, route spec.Route) error {
 	handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
 	if !ok {
 		return fmt.Errorf("missing handler annotation for %q", route.Path)
 	}
-	handler = getHandlerName(handler)
-	var reqBody string
-	if len(route.RequestType.Name) > 0 {
-		var bodyBuilder strings.Builder
-		t := template.Must(template.New("parseRequest").Parse(parseRequestTemplate))
-		if err := t.Execute(&bodyBuilder, map[string]string{
-			"requestType": typesPacket + "." + util.Title(route.RequestType.Name),
-		}); err != nil {
-			return err
-		}
-		reqBody = bodyBuilder.String()
-	}
 
-	var req = "req"
-	if len(route.RequestType.Name) == 0 {
-		req = ""
-	}
-	var logicResponse string
-	var writeResponse string
-	var respWriter = `httpx.WriteJson(w, http.StatusOK, resp)`
-	if len(route.ResponseType.Name) > 0 {
-		logicResponse = "resp, err :="
-		writeResponse = "resp, err"
-	} else {
-		logicResponse = "err :="
-		writeResponse = "nil, err"
-		respWriter = `httpx.Ok(w)`
+	handler = getHandlerName(handler)
+	if getHandlerFolderPath(group, route) != handlerDir {
+		handler = strings.Title(handler)
 	}
-	var logicBodyBuilder strings.Builder
-	t := template.Must(template.New("hasRespTemplate").Parse(hasRespTemplate))
-	if err := t.Execute(&logicBodyBuilder, map[string]string{
-		"logic":         "New" + strings.TrimSuffix(strings.Title(handler), "Handler") + "Logic",
-		"callee":        strings.Title(strings.TrimSuffix(handler, "Handler")),
-		"req":           req,
-		"logicResponse": logicResponse,
-		"writeResponse": writeResponse,
-		"respWriter":    respWriter,
-	}); err != nil {
+	parentPkg, err := getParentPackage(dir)
+	if err != nil {
 		return err
 	}
-	respBody := logicBodyBuilder.String()
-
-	if !strings.HasSuffix(handler, "Handler") {
-		handler = handler + "Handler"
-	}
 
-	var bodyBuilder strings.Builder
-	bodyTemplate := template.Must(template.New("handlerBodyTemplate").Parse(handlerBodyTemplate))
-	if err := bodyTemplate.Execute(&bodyBuilder, map[string]string{
-		"parseRequest": reqBody,
-		"processBody":  respBody,
-	}); err != nil {
-		return err
-	}
-	return doGenToFile(dir, handler, group, route, bodyBuilder)
+	return doGenToFile(dir, handler, group, route, Handler{
+		ImportPackages: genHandlerImports(group, route, parentPkg),
+		HandlerName:    handler,
+		RequestType:    util.Title(route.RequestType.Name),
+		LogicType:      strings.TrimSuffix(strings.Title(handler), "Handler") + "Logic",
+		Call:           strings.Title(strings.TrimSuffix(handler, "Handler")),
+		HasResp:        len(route.ResponseType.Name) > 0,
+	})
 }
 
-func doGenToFile(dir, handler string, group spec.Group, route spec.Route, bodyBuilder strings.Builder) error {
+func doGenToFile(dir, handler string, group spec.Group, route spec.Route, handleObj Handler) error {
 	if getHandlerFolderPath(group, route) != handlerDir {
 		handler = strings.Title(handler)
 	}
-	parentPkg, err := getParentPackage(dir)
-	if err != nil {
-		return err
-	}
 	filename := strings.ToLower(handler)
 	if strings.HasSuffix(filename, "handler") {
 		filename = filename + ".go"
@@ -132,16 +93,18 @@ func doGenToFile(dir, handler string, group spec.Group, route spec.Route, bodyBu
 		return nil
 	}
 	defer fp.Close()
-	t := template.Must(template.New("handlerTemplate").Parse(handlerTemplate))
+
+	text, err := templatex.LoadTemplate(category, handlerTemplateFile, handlerTemplate)
+	if err != nil {
+		return err
+	}
+
 	buffer := new(bytes.Buffer)
-	err = t.Execute(buffer, map[string]string{
-		"importPackages": genHandlerImports(group, route, parentPkg),
-		"handlerName":    handler,
-		"handlerBody":    strings.TrimSpace(bodyBuilder.String()),
-	})
+	err = template.Must(template.New("handlerTemplate").Parse(text)).Execute(buffer, handleObj)
 	if err != nil {
 		return nil
 	}
+
 	formatCode := formatCode(buffer.String())
 	_, err = fp.WriteString(formatCode)
 	return err

+ 7 - 1
tools/goctl/api/gogen/genlogic.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 	"github.com/tal-tech/go-zero/tools/goctl/api/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/vars"
 )
@@ -93,7 +94,12 @@ func genLogicByRoute(dir string, group spec.Group, route spec.Route) error {
 		requestString = "req " + "types." + strings.Title(route.RequestType.Name)
 	}
 
-	t := template.Must(template.New("logicTemplate").Parse(logicTemplate))
+	text, err := templatex.LoadTemplate(category, logicTemplateFile, logicTemplate)
+	if err != nil {
+		return err
+	}
+
+	t := template.Must(template.New("logicTemplate").Parse(text))
 	buffer := new(bytes.Buffer)
 	err = t.Execute(fp, map[string]string{
 		"imports":      imports,

+ 7 - 1
tools/goctl/api/gogen/genmain.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 	"github.com/tal-tech/go-zero/tools/goctl/api/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/vars"
 )
@@ -60,7 +61,12 @@ func genMain(dir string, api *spec.ApiSpec) error {
 		return err
 	}
 
-	t := template.Must(template.New("mainTemplate").Parse(mainTemplate))
+	text, err := templatex.LoadTemplate(category, mainTemplateFile, mainTemplate)
+	if err != nil {
+		return err
+	}
+
+	t := template.Must(template.New("mainTemplate").Parse(text))
 	buffer := new(bytes.Buffer)
 	err = t.Execute(buffer, map[string]string{
 		"importPackages": genMainImports(parentPkg),

+ 8 - 1
tools/goctl/api/gogen/gensvc.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 	"github.com/tal-tech/go-zero/tools/goctl/api/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
@@ -46,8 +47,14 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
 	if err != nil {
 		return err
 	}
+
+	text, err := templatex.LoadTemplate(category, contextTemplateFile, contextTemplate)
+	if err != nil {
+		return err
+	}
+
 	var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
-	t := template.Must(template.New("contextTemplate").Parse(contextTemplate))
+	t := template.Must(template.New("contextTemplate").Parse(text))
 	buffer := new(bytes.Buffer)
 	err = t.Execute(buffer, map[string]string{
 		"configImport": configImport,

+ 29 - 0
tools/goctl/api/gogen/template.go

@@ -0,0 +1,29 @@
+package gogen
+
+import (
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
+	"github.com/urfave/cli"
+)
+
+const (
+	category            = "api"
+	configTemplateFile  = "config.tpl"
+	contextTemplateFile = "context.tpl"
+	etcTemplateFile     = "etc.tpl"
+	handlerTemplateFile = "handler.tpl"
+	logicTemplateFile   = "logic.tpl"
+	mainTemplateFile    = "main.tpl"
+)
+
+var templates = map[string]string{
+	configTemplateFile:  configTemplate,
+	contextTemplateFile: contextTemplate,
+	etcTemplateFile:     etcTemplate,
+	handlerTemplateFile: handlerTemplate,
+	logicTemplateFile:   logicTemplate,
+	mainTemplateFile:    mainTemplate,
+}
+
+func GenTemplates(_ *cli.Context) error {
+	return templatex.InitTemplates(category, templates)
+}

+ 7 - 0
tools/goctl/goctl.go

@@ -102,6 +102,13 @@ var (
 						},
 					},
 					Action: gogen.GoCommand,
+					Subcommands: []cli.Command{
+						{
+							Name:   "template",
+							Usage:  "initialize the api templates",
+							Action: gogen.GenTemplates,
+						},
+					},
 				},
 				{
 					Name:  "java",

+ 2 - 2
tools/goctl/model/sql/gen/delete.go

@@ -5,7 +5,7 @@ import (
 
 	"github.com/tal-tech/go-zero/core/collection"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
@@ -22,7 +22,7 @@ func genDelete(table Table, withCache bool) (string, error) {
 	}
 
 	camel := table.Name.ToCamel()
-	output, err := util.With("delete").
+	output, err := templatex.With("delete").
 		Parse(template.Delete).
 		Execute(map[string]interface{}{
 			"upperStartCamelObject":     camel,

+ 2 - 2
tools/goctl/model/sql/gen/field.go

@@ -5,7 +5,7 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 )
 
 func genFields(fields []parser.Field) (string, error) {
@@ -25,7 +25,7 @@ func genField(field parser.Field) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	output, err := util.With("types").
+	output, err := templatex.With("types").
 		Parse(template.Field).
 		Execute(map[string]interface{}{
 			"name":       field.Name.ToCamel(),

+ 2 - 2
tools/goctl/model/sql/gen/findone.go

@@ -2,13 +2,13 @@ package gen
 
 import (
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
 func genFindOne(table Table, withCache bool) (string, error) {
 	camel := table.Name.ToCamel()
-	output, err := util.With("findOne").
+	output, err := templatex.With("findOne").
 		Parse(template.FindOne).
 		Execute(map[string]interface{}{
 			"withCache":                 withCache,

+ 3 - 3
tools/goctl/model/sql/gen/findonebyfield.go

@@ -5,12 +5,12 @@ import (
 	"strings"
 
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
 func genFindOneByField(table Table, withCache bool) (string, string, error) {
-	t := util.With("findOneByField").Parse(template.FindOneByField)
+	t := templatex.With("findOneByField").Parse(template.FindOneByField)
 	var list []string
 	camelTableName := table.Name.ToCamel()
 	for _, field := range table.Fields {
@@ -36,7 +36,7 @@ func genFindOneByField(table Table, withCache bool) (string, string, error) {
 		list = append(list, output.String())
 	}
 	if withCache {
-		out, err := util.With("findOneByFieldExtraMethod").Parse(template.FindOneByFieldExtraMethod).Execute(map[string]interface{}{
+		out, err := templatex.With("findOneByFieldExtraMethod").Parse(template.FindOneByFieldExtraMethod).Execute(map[string]interface{}{
 			"upperStartCamelObject": camelTableName,
 			"primaryKeyLeft":        table.CacheKey[table.PrimaryKey.Name.Source()].Left,
 			"lowerStartCamelObject": stringx.From(camelTableName).UnTitle(),

+ 2 - 1
tools/goctl/model/sql/gen/gen.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util/console"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
@@ -119,7 +120,7 @@ type (
 )
 
 func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) {
-	t := util.With("model").
+	t := templatex.With("model").
 		Parse(template.Model).
 		GoFmt(true)
 

+ 3 - 3
tools/goctl/model/sql/gen/imports.go

@@ -2,12 +2,12 @@ package gen
 
 import (
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 )
 
 func genImports(withCache, timeImport bool) (string, error) {
 	if withCache {
-		buffer, err := util.With("import").Parse(template.Imports).Execute(map[string]interface{}{
+		buffer, err := templatex.With("import").Parse(template.Imports).Execute(map[string]interface{}{
 			"time": timeImport,
 		})
 		if err != nil {
@@ -15,7 +15,7 @@ func genImports(withCache, timeImport bool) (string, error) {
 		}
 		return buffer.String(), nil
 	} else {
-		buffer, err := util.With("import").Parse(template.ImportsNoCache).Execute(map[string]interface{}{
+		buffer, err := templatex.With("import").Parse(template.ImportsNoCache).Execute(map[string]interface{}{
 			"time": timeImport,
 		})
 		if err != nil {

+ 2 - 2
tools/goctl/model/sql/gen/insert.go

@@ -5,7 +5,7 @@ import (
 
 	"github.com/tal-tech/go-zero/core/collection"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
@@ -34,7 +34,7 @@ func genInsert(table Table, withCache bool) (string, error) {
 		expressionValues = append(expressionValues, "data."+camel)
 	}
 	camel := table.Name.ToCamel()
-	output, err := util.With("insert").
+	output, err := templatex.With("insert").
 		Parse(template.Insert).
 		Execute(map[string]interface{}{
 			"withCache":             withCache,

+ 2 - 2
tools/goctl/model/sql/gen/new.go

@@ -2,11 +2,11 @@ package gen
 
 import (
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 )
 
 func genNew(table Table, withCache bool) (string, error) {
-	output, err := util.With("new").
+	output, err := templatex.With("new").
 		Parse(template.New).
 		Execute(map[string]interface{}{
 			"withCache":             withCache,

+ 2 - 2
tools/goctl/model/sql/gen/tag.go

@@ -2,14 +2,14 @@ package gen
 
 import (
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 )
 
 func genTag(in string) (string, error) {
 	if in == "" {
 		return in, nil
 	}
-	output, err := util.With("tag").
+	output, err := templatex.With("tag").
 		Parse(template.Tag).
 		Execute(map[string]interface{}{
 			"field": in,

+ 2 - 2
tools/goctl/model/sql/gen/types.go

@@ -2,7 +2,7 @@ package gen
 
 import (
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 )
 
 func genTypes(table Table, withCache bool) (string, error) {
@@ -11,7 +11,7 @@ func genTypes(table Table, withCache bool) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	output, err := util.With("types").
+	output, err := templatex.With("types").
 		Parse(template.Types).
 		Execute(map[string]interface{}{
 			"withCache":             withCache,

+ 2 - 2
tools/goctl/model/sql/gen/update.go

@@ -4,7 +4,7 @@ import (
 	"strings"
 
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
@@ -22,7 +22,7 @@ func genUpdate(table Table, withCache bool) (string, error) {
 	}
 	expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel())
 	camelTableName := table.Name.ToCamel()
-	output, err := util.With("update").
+	output, err := templatex.With("update").
 		Parse(template.Update).
 		Execute(map[string]interface{}{
 			"withCache":             withCache,

+ 2 - 2
tools/goctl/model/sql/gen/vars.go

@@ -4,7 +4,7 @@ import (
 	"strings"
 
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
@@ -14,7 +14,7 @@ func genVars(table Table, withCache bool) (string, error) {
 		keys = append(keys, v.VarExpression)
 	}
 	camel := table.Name.ToCamel()
-	output, err := util.With("var").
+	output, err := templatex.With("var").
 		Parse(template.Vars).
 		GoFmt(true).
 		Execute(map[string]interface{}{

+ 6 - 5
tools/goctl/rpc/gen/gencall.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/tal-tech/go-zero/core/collection"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
@@ -122,8 +123,8 @@ func (g *defaultRpcGenerator) genCall() error {
 	}
 
 	filename := filepath.Join(callPath, typesFilename)
-	head := util.GetHead(g.Ctx.ProtoSource)
-	err = util.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{
+	head := templatex.GetHead(g.Ctx.ProtoSource)
+	err = templatex.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{
 		"head":                  head,
 		"const":                 constLit,
 		"filePackage":           service.Name.Lower(),
@@ -146,7 +147,7 @@ func (g *defaultRpcGenerator) genCall() error {
 		return err
 	}
 
-	err = util.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{
+	err = templatex.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{
 		"name":        service.Name.Lower(),
 		"head":        head,
 		"filePackage": service.Name.Lower(),
@@ -166,7 +167,7 @@ func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string,
 	imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
 	for _, method := range service.Funcs {
 		imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
-		buffer, err := util.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{
+		buffer, err := templatex.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{
 			"rpcServiceName": service.Name.Title(),
 			"method":         method.Name.Title(),
 			"package":        pkgName,
@@ -189,7 +190,7 @@ func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]s
 	functions := make([]string, 0)
 
 	for _, method := range service.Funcs {
-		buffer, err := util.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute(
+		buffer, err := templatex.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute(
 			map[string]interface{}{
 				"hasComment": method.HaveDoc(),
 				"comment":    method.GetDoc(),

+ 2 - 1
tools/goctl/rpc/gen/genetc.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"path/filepath"
 
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
@@ -22,7 +23,7 @@ func (g *defaultRpcGenerator) genEtc() error {
 		return nil
 	}
 
-	return util.With("etc").Parse(etcTemplate).SaveTo(map[string]interface{}{
+	return templatex.With("etc").Parse(etcTemplate).SaveTo(map[string]interface{}{
 		"serviceName": g.Ctx.ServiceName.Lower(),
 	}, fileName, false)
 }

+ 3 - 2
tools/goctl/rpc/gen/genlogic.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/tal-tech/go-zero/core/collection"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
@@ -61,7 +62,7 @@ func (g *defaultRpcGenerator) genLogic() error {
 			svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
 			imports.AddStr(svcImport)
 			imports.AddStr(importList...)
-			err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
+			err = templatex.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
 				"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
 				"functions": functions,
 				"imports":   strings.Join(imports.KeysStr(), util.NL),
@@ -82,7 +83,7 @@ func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parse
 	}
 	imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
 	imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
-	buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
+	buffer, err := templatex.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
 		"logicName":    fmt.Sprintf("%sLogic", method.Name.Title()),
 		"method":       method.Name.Title(),
 		"request":      method.ParameterIn.StarExpression,

+ 3 - 2
tools/goctl/rpc/gen/genmain.go

@@ -6,6 +6,7 @@ import (
 	"strings"
 
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
@@ -58,8 +59,8 @@ func (g *defaultRpcGenerator) genMain() error {
 	configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig))
 	imports = append(imports, configImport, pbImport, remoteImport, svcImport)
 	srv, registers := g.genServer(pkg, file.Service)
-	head := util.GetHead(g.Ctx.ProtoSource)
-	return util.With("main").GoFmt(true).Parse(mainTemplate).SaveTo(map[string]interface{}{
+	head := templatex.GetHead(g.Ctx.ProtoSource)
+	return templatex.With("main").GoFmt(true).Parse(mainTemplate).SaveTo(map[string]interface{}{
 		"head":        head,
 		"package":     pkg,
 		"serviceName": g.Ctx.ServiceName.Lower(),

+ 4 - 3
tools/goctl/rpc/gen/genserver.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/tal-tech/go-zero/core/collection"
 	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
@@ -51,7 +52,7 @@ func (g *defaultRpcGenerator) genHandler() error {
 	imports := collection.NewSet()
 	imports.AddStr(logicImport, svcImport)
 
-	head := util.GetHead(g.Ctx.ProtoSource)
+	head := templatex.GetHead(g.Ctx.ProtoSource)
 	for _, service := range file.Service {
 		filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
 		serverFile := filepath.Join(serverPath, filename)
@@ -60,7 +61,7 @@ func (g *defaultRpcGenerator) genHandler() error {
 			return err
 		}
 		imports.AddStr(importList...)
-		err = util.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{
+		err = templatex.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{
 			"head":    head,
 			"types":   fmt.Sprintf(typeFmt, service.Name.Title()),
 			"server":  service.Name.Title(),
@@ -85,7 +86,7 @@ func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string
 		}
 		imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
 		imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
-		buffer, err := util.With("func").Parse(functionTemplate).Execute(map[string]interface{}{
+		buffer, err := templatex.With("func").Parse(functionTemplate).Execute(map[string]interface{}{
 			"server":     service.Name.Title(),
 			"logicName":  fmt.Sprintf("%sLogic", method.Name.Title()),
 			"method":     method.Name.Title(),

+ 2 - 2
tools/goctl/rpc/gen/gensvc.go

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"path/filepath"
 
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 )
 
 const svcTemplate = `package svc
@@ -25,7 +25,7 @@ func NewServiceContext(c config.Config) *ServiceContext {
 func (g *defaultRpcGenerator) genSvc() error {
 	svcPath := g.dirM[dirSvc]
 	fileName := filepath.Join(svcPath, fileServiceContext)
-	return util.With("svc").GoFmt(true).Parse(svcTemplate).SaveTo(map[string]interface{}{
+	return templatex.With("svc").GoFmt(true).Parse(svcTemplate).SaveTo(map[string]interface{}{
 		"imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)),
 	}, fileName, false)
 }

+ 2 - 2
tools/goctl/rpc/gen/template.go

@@ -4,7 +4,7 @@ import (
 	"path/filepath"
 	"strings"
 
-	"github.com/tal-tech/go-zero/tools/goctl/util"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util/console"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
@@ -43,7 +43,7 @@ func (r *rpcTemplate) MustGenerate(showState bool) {
 	r.Info("generating template...")
 	protoFilename := filepath.Base(r.out)
 	serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename)))
-	err := util.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{
+	err := templatex.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{
 		"package":     serviceName.UnTitle(),
 		"serviceName": serviceName.Title(),
 	}, r.out, false)

+ 4 - 3
tools/goctl/rpc/parser/pbast.go

@@ -12,6 +12,7 @@ import (
 
 	"github.com/tal-tech/go-zero/core/lang"
 	sx "github.com/tal-tech/go-zero/core/stringx"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util/console"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
@@ -589,7 +590,7 @@ func (a *PbAst) GenTypesCode() (string, error) {
 		types = append(types, typeCode)
 	}
 
-	buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
+	buffer, err := templatex.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
 		"types": strings.Join(types, util.NL+util.NL),
 	})
 	if err != nil {
@@ -614,7 +615,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
 			comment = f.Comment[0]
 		}
 		doc = strings.Join(f.Document, util.NL)
-		buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
+		buffer, err := templatex.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
 			"name":       f.Name.Title(),
 			"type":       f.Type.InvokeTypeExpression,
 			"tag":        f.JsonTag,
@@ -629,7 +630,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
 
 		fields = append(fields, buffer.String())
 	}
-	buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
+	buffer, err := templatex.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
 		"type":   containsTypeStatement,
 		"name":   s.Name.Title(),
 		"fields": strings.Join(fields, util.NL),

+ 4 - 3
tools/goctl/rpc/parser/proto.go

@@ -10,6 +10,7 @@ import (
 	"github.com/emicklei/proto"
 	"github.com/tal-tech/go-zero/core/collection"
 	"github.com/tal-tech/go-zero/core/lang"
+	"github.com/tal-tech/go-zero/tools/goctl/templatex"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
@@ -262,7 +263,7 @@ func (e *Enum) GenEnumCode() (string, error) {
 		}
 		element = append(element, code)
 	}
-	buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
+	buffer, err := templatex.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
 		"element": strings.Join(element, util.NL),
 	})
 	if err != nil {
@@ -272,7 +273,7 @@ func (e *Enum) GenEnumCode() (string, error) {
 }
 
 func (e *Enum) GenEnumTypeCode() (string, error) {
-	buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
+	buffer, err := templatex.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
 		"name": e.Name.Source(),
 	})
 	if err != nil {
@@ -282,7 +283,7 @@ func (e *Enum) GenEnumTypeCode() (string, error) {
 }
 
 func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) {
-	buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
+	buffer, err := templatex.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
 		"key":   e.Key,
 		"name":  parentName,
 		"value": e.Value,

+ 79 - 0
tools/goctl/templatex/files.go

@@ -0,0 +1,79 @@
+package templatex
+
+import (
+	"fmt"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+
+	"github.com/logrusorgru/aurora"
+	"github.com/tal-tech/go-zero/tools/goctl/util"
+)
+
+const goctlDir = ".goctl"
+
+func InitTemplates(category string, templates map[string]string) error {
+	dir, err := getTemplateDir(category)
+	if err != nil {
+		return err
+	}
+
+	if err := util.MkdirIfNotExist(dir); err != nil {
+		return err
+	}
+
+	for k, v := range templates {
+		if err := createTemplate(filepath.Join(dir, k), v); err != nil {
+			return err
+		}
+	}
+
+	fmt.Printf("Templates are generated in %s, %s\n", aurora.Green(dir),
+		aurora.Red("edit on your risk!"))
+
+	return nil
+}
+
+func LoadTemplate(category, file, builtin string) (string, error) {
+	dir, err := getTemplateDir(category)
+	if err != nil {
+		return "", err
+	}
+
+	file = filepath.Join(dir, file)
+	if !util.FileExists(file) {
+		return builtin, nil
+	}
+
+	content, err := ioutil.ReadFile(file)
+	if err != nil {
+		return "", err
+	}
+
+	return string(content), nil
+}
+
+func createTemplate(file, content string) error {
+	if util.FileExists(file) {
+		println(1)
+		return nil
+	}
+
+	f, err := os.Create(file)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	_, err = f.WriteString(content)
+	return err
+}
+
+func getTemplateDir(category string) (string, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	return filepath.Join(home, goctlDir, category), nil
+}

+ 1 - 1
tools/goctl/util/head.go → tools/goctl/templatex/head.go

@@ -1,4 +1,4 @@
-package util
+package templatex
 
 var headTemplate = `// Code generated by goctl. DO NOT EDIT!
 // Source: {{.source}}`

+ 21 - 19
tools/goctl/util/templatex.go → tools/goctl/templatex/templatex.go

@@ -1,22 +1,23 @@
-package util
+package templatex
 
 import (
 	"bytes"
 	goformat "go/format"
 	"io/ioutil"
-	"os"
 	"text/template"
-)
 
-type (
-	defaultTemplate struct {
-		name     string
-		text     string
-		goFmt    bool
-		savePath string
-	}
+	"github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
+const regularPerm = 0666
+
+type defaultTemplate struct {
+	name     string
+	text     string
+	goFmt    bool
+	savePath string
+}
+
 func With(name string) *defaultTemplate {
 	return &defaultTemplate{
 		name: name,
@@ -33,37 +34,38 @@ func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate {
 }
 
 func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error {
-	if FileExists(path) && !forceUpdate {
+	if util.FileExists(path) && !forceUpdate {
 		return nil
 	}
-	output, err := t.execute(data)
+
+	output, err := t.Execute(data)
 	if err != nil {
 		return err
 	}
-	return ioutil.WriteFile(path, output.Bytes(), os.ModePerm)
-}
 
-func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) {
-	return t.execute(data)
+	return ioutil.WriteFile(path, output.Bytes(), regularPerm)
 }
 
-func (t *defaultTemplate) execute(data interface{}) (*bytes.Buffer, error) {
+func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) {
 	tem, err := template.New(t.name).Parse(t.text)
 	if err != nil {
 		return nil, err
 	}
+
 	buf := new(bytes.Buffer)
-	err = tem.Execute(buf, data)
-	if err != nil {
+	if err = tem.Execute(buf, data); err != nil {
 		return nil, err
 	}
+
 	if !t.goFmt {
 		return buf, nil
 	}
+
 	formatOutput, err := goformat.Source(buf.Bytes())
 	if err != nil {
 		return nil, err
 	}
+
 	buf.Reset()
 	buf.Write(formatOutput)
 	return buf, nil