瀏覽代碼

feat(goctl): supports model code 'DO NOT EDIT' (#1728)

Resolves: #1710
Fyn 3 年之前
父節點
當前提交
f4471846ff

+ 55 - 10
tools/goctl/model/sql/gen/gen.go

@@ -49,6 +49,11 @@ type (
 		deleteCode  string
 		cacheExtra  string
 	}
+
+	codeTuple struct {
+		modelCode       string
+		modelCustomCode string
+	}
 )
 
 // NewDefaultGenerator creates an instance for defaultGenerator
@@ -109,7 +114,7 @@ func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, databas
 }
 
 func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error {
-	m := make(map[string]string)
+	m := make(map[string]*codeTuple)
 	for _, each := range tables {
 		table, err := parser.ConvertDataType(each)
 		if err != nil {
@@ -120,14 +125,21 @@ func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.T
 		if err != nil {
 			return err
 		}
+		customCode, err := g.genModelCustom(*table)
+		if err != nil {
+			return err
+		}
 
-		m[table.Name.Source()] = code
+		m[table.Name.Source()] = &codeTuple{
+			modelCode:       code,
+			modelCustomCode: customCode,
+		}
 	}
 
 	return g.createFile(m)
 }
 
-func (g *defaultGenerator) createFile(modelList map[string]string) error {
+func (g *defaultGenerator) createFile(modelList map[string]*codeTuple) error {
 	dirAbs, err := filepath.Abs(g.dir)
 	if err != nil {
 		return err
@@ -140,20 +152,27 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
 		return err
 	}
 
-	for tableName, code := range modelList {
+	for tableName, codes := range modelList {
 		tn := stringx.From(tableName)
 		modelFilename, err := format.FileNamingFormat(g.cfg.NamingFormat, fmt.Sprintf("%s_model", tn.Source()))
 		if err != nil {
 			return err
 		}
 
-		name := util.SafeString(modelFilename) + ".go"
+		name := util.SafeString(modelFilename) + "_gen.go"
 		filename := filepath.Join(dirAbs, name)
+		err = ioutil.WriteFile(filename, []byte(codes.modelCode), os.ModePerm)
+		if err != nil {
+			return err
+		}
+
+		name = util.SafeString(modelFilename) + ".go"
+		filename = filepath.Join(dirAbs, name)
 		if pathx.FileExists(filename) {
 			g.Warning("%s already exists, ignored.", name)
 			continue
 		}
-		err = ioutil.WriteFile(filename, []byte(code), os.ModePerm)
+		err = ioutil.WriteFile(filename, []byte(codes.modelCustomCode), os.ModePerm)
 		if err != nil {
 			return err
 		}
@@ -183,8 +202,8 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
 }
 
 // ret1: key-table name,value-code
-func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]string, error) {
-	m := make(map[string]string)
+func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]*codeTuple, error) {
+	m := make(map[string]*codeTuple)
 	tables, err := parser.Parse(filename, database)
 	if err != nil {
 		return nil, err
@@ -195,8 +214,15 @@ func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database
 		if err != nil {
 			return nil, err
 		}
+		customCode, err := g.genModelCustom(*e)
+		if err != nil {
+			return nil, err
+		}
 
-		m[e.Name.Source()] = code
+		m[e.Name.Source()] = &codeTuple{
+			modelCode:       code,
+			modelCustomCode: customCode,
+		}
 	}
 
 	return m, nil
@@ -292,8 +318,27 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
 	return output.String(), nil
 }
 
+func (g *defaultGenerator) genModelCustom(in parser.Table) (string, error) {
+	text, err := pathx.LoadTemplate(category, modelCustomTemplateFile, template.ModelCustom)
+	if err != nil {
+		return "", err
+	}
+	t := util.With("model-custom").
+		Parse(text).
+		GoFmt(true)
+	output, err := t.Execute(map[string]interface{}{
+		"pkg":                   g.pkg,
+		"upperStartCamelObject": in.Name.ToCamel(),
+		"lowerStartCamelObject": stringx.From(in.Name.ToCamel()).Untitle(),
+	})
+	if err != nil {
+		return "", err
+	}
+	return output.String(), nil
+}
+
 func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer, error) {
-	text, err := pathx.LoadTemplate(category, modelTemplateFile, template.Model)
+	text, err := pathx.LoadTemplate(category, modelGenTemplateFile, template.ModelGen)
 	if err != nil {
 		return nil, err
 	}

+ 28 - 0
tools/goctl/model/sql/gen/gen_test.go

@@ -4,16 +4,19 @@ import (
 	"database/sql"
 	"io/ioutil"
 	"os"
+	"path"
 	"path/filepath"
 	"strings"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stringx"
 	"github.com/zeromicro/go-zero/tools/goctl/config"
 	"github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx"
+	"github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
 	"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
 )
 
@@ -121,3 +124,28 @@ func TestFields(t *testing.T) {
 	assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
 	assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
 }
+
+func Test_genPublicModel(t *testing.T) {
+	var err error
+	dir := pathx.MustTempDir()
+	modelDir := path.Join(dir, "model")
+	err = os.MkdirAll(modelDir, 0777)
+	require.NoError(t, err)
+	defer os.RemoveAll(dir)
+
+	modelFilename := filepath.Join(modelDir, "foo.sql")
+	err = ioutil.WriteFile(modelFilename, []byte(source), 0777)
+	require.NoError(t, err)
+
+	g, err := NewDefaultGenerator(modelDir, &config.Config{
+		NamingFormat: config.DefaultFormat,
+	})
+	require.NoError(t, err)
+
+	tables, err := parser.Parse(modelFilename, "")
+	require.Equal(t, 1, len(tables))
+
+	code, err := g.genModelCustom(*tables[0])
+	assert.NoError(t, err)
+	assert.Equal(t, "package model\n\ntype TestUserModel interface {\n\ttestUserModel\n}\n", code)
+}

+ 4 - 2
tools/goctl/model/sql/gen/template.go

@@ -22,7 +22,8 @@ const (
 	importsWithNoCacheTemplateFile        = "import-no-cache.tpl"
 	insertTemplateFile                    = "insert.tpl"
 	insertTemplateMethodFile              = "interface-insert.tpl"
-	modelTemplateFile                     = "model.tpl"
+	modelGenTemplateFile                  = "model-gen.tpl"
+	modelCustomTemplateFile               = "model.tpl"
 	modelNewTemplateFile                  = "model-new.tpl"
 	tagTemplateFile                       = "tag.tpl"
 	typesTemplateFile                     = "types.tpl"
@@ -45,7 +46,8 @@ var templates = map[string]string{
 	importsWithNoCacheTemplateFile:        template.ImportsNoCache,
 	insertTemplateFile:                    template.Insert,
 	insertTemplateMethodFile:              template.InsertMethod,
-	modelTemplateFile:                     template.Model,
+	modelGenTemplateFile:                  template.ModelGen,
+	modelCustomTemplateFile:               template.ModelCustom,
 	modelNewTemplateFile:                  template.New,
 	tagTemplateFile:                       template.Tag,
 	typesTemplateFile:                     template.Types,

+ 2 - 0
tools/goctl/model/sql/gen/types.go

@@ -4,6 +4,7 @@ import (
 	"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
 	"github.com/zeromicro/go-zero/tools/goctl/util"
 	"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
+	"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
 )
 
 func genTypes(table Table, methods string, withCache bool) (string, error) {
@@ -24,6 +25,7 @@ func genTypes(table Table, methods string, withCache bool) (string, error) {
 			"withCache":             withCache,
 			"method":                methods,
 			"upperStartCamelObject": table.Name.ToCamel(),
+			"lowerStartCamelObject": stringx.From(table.Name.ToCamel()).Untitle(),
 			"fields":                fieldsString,
 			"data":                  table,
 		})

+ 18 - 3
tools/goctl/model/sql/template/model.go

@@ -1,7 +1,15 @@
 package template
 
-// Model defines a template for model
-var Model = `package {{.pkg}}
+import (
+	"fmt"
+
+	"github.com/zeromicro/go-zero/tools/goctl/util"
+)
+
+// ModelGen defines a template for model
+var ModelGen = fmt.Sprintf(`%s
+
+package {{.pkg}}
 {{.imports}}
 {{.vars}}
 {{.types}}
@@ -11,4 +19,11 @@ var Model = `package {{.pkg}}
 {{.update}}
 {{.delete}}
 {{.extraMethod}}
-`
+`, util.DoNotEditHead)
+
+// ModelCustom defines a template for extension
+var ModelCustom = fmt.Sprintf(`package {{.pkg}}
+type {{.upperStartCamelObject}}Model interface {
+	{{.lowerStartCamelObject}}Model
+}
+`)

+ 1 - 1
tools/goctl/model/sql/template/types.go

@@ -3,7 +3,7 @@ package template
 // Types defines a template for types in model
 var Types = `
 type (
-	{{.upperStartCamelObject}}Model interface{
+	{{.lowerStartCamelObject}}Model interface{
 		{{.method}}
 	}
 

+ 2 - 0
tools/goctl/model/sql/template/vars.go

@@ -5,6 +5,8 @@ import "fmt"
 // Vars defines a template for var block in model
 var Vars = fmt.Sprintf(`
 var (
+	_ {{.upperStartCamelObject}}Model = (*default{{.upperStartCamelObject}}Model)(nil)
+
 	{{.lowerStartCamelObject}}FieldNames          = builder.RawFieldNames(&{{.upperStartCamelObject}}{}{{if .postgreSql}},true{{end}})
 	{{.lowerStartCamelObject}}Rows                = strings.Join({{.lowerStartCamelObject}}FieldNames, ",")
 	{{.lowerStartCamelObject}}RowsExpectAutoSet   = {{if .postgreSql}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{else}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{end}}

+ 3 - 0
tools/goctl/util/head.go

@@ -1,5 +1,8 @@
 package util
 
+// DoNotEditHead added to the beginning of a file to prompt the user not to edit
+var DoNotEditHead = "// Code generated by goctl. DO NOT EDIT!"
+
 var headTemplate = `// Code generated by goctl. DO NOT EDIT!
 // Source: {{.source}}`