Przeglądaj źródła

feat: Support model code generation for multi tables (#1836)

* Support model code generation for multi tables

* Format code

* Format code

Co-authored-by: anqiansong <anqiansong@bytedance.com>
anqiansong 3 lat temu
rodzic
commit
cc09ab2aba

+ 1 - 1
tools/goctl/goctl.go

@@ -725,7 +725,7 @@ var commands = []cli.Command{
 								Name:  "url",
 								Usage: `the data source of database,like "root:password@tcp(127.0.0.1:3306)/database"`,
 							},
-							cli.StringFlag{
+							cli.StringSliceFlag{
 								Name:  "table, t",
 								Usage: `the table or table globbing patterns in the database`,
 							},

+ 43 - 10
tools/goctl/model/sql/command/command.go

@@ -87,13 +87,51 @@ func MySqlDataSource(ctx *cli.Context) error {
 		pathx.RegisterGoctlHome(home)
 	}
 
-	pattern := strings.TrimSpace(ctx.String(flagTable))
+	tableValue := ctx.StringSlice(flagTable)
+	patterns := parseTableList(tableValue)
 	cfg, err := config.NewConfig(style)
 	if err != nil {
 		return err
 	}
 
-	return fromMysqlDataSource(url, pattern, dir, cfg, cache, idea)
+	return fromMysqlDataSource(url, dir, patterns, cfg, cache, idea)
+}
+
+type pattern map[string]struct{}
+
+func (p pattern) Match(s string) bool {
+	for v := range p {
+		match, err := filepath.Match(v, s)
+		if err != nil {
+			console.Error("%+v", err)
+			continue
+		}
+		if match {
+			return true
+		}
+	}
+	return false
+}
+
+func (p pattern) list() []string {
+	var ret []string
+	for v := range p {
+		ret = append(ret, v)
+	}
+	return ret
+}
+
+func parseTableList(tableValue []string) pattern {
+	tablePattern := make(pattern)
+	for _, v := range tableValue {
+		fields := strings.FieldsFunc(v, func(r rune) bool {
+			return r == ','
+		})
+		for _, f := range fields {
+			tablePattern[f] = struct{}{}
+		}
+	}
+	return tablePattern
 }
 
 // PostgreSqlDataSource generates model code from datasource
@@ -162,14 +200,14 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str
 	return nil
 }
 
-func fromMysqlDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error {
+func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, cache, idea bool) error {
 	log := console.NewConsole(idea)
 	if len(url) == 0 {
 		log.Error("%v", "expected data source of mysql, but nothing found")
 		return nil
 	}
 
-	if len(pattern) == 0 {
+	if len(tablePat) == 0 {
 		log.Error("%v", "expected table or table globbing patterns, but nothing found")
 		return nil
 	}
@@ -191,12 +229,7 @@ func fromMysqlDataSource(url, pattern, dir string, cfg *config.Config, cache, id
 
 	matchTables := make(map[string]*model.Table)
 	for _, item := range tables {
-		match, err := filepath.Match(pattern, item)
-		if err != nil {
-			return err
-		}
-
-		if !match {
+		if !tablePat.Match(item) {
 			continue
 		}
 

+ 29 - 0
tools/goctl/model/sql/command/command_test.go

@@ -5,6 +5,8 @@ import (
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sort"
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -86,3 +88,30 @@ func TestFromDDl(t *testing.T) {
 	_ = os.Remove(filename)
 	fromDDL("1gozero")
 }
+
+func Test_parseTableList(t *testing.T) {
+	testData := []string{"foo", "b*", "bar", "back_up", "foo,bar,b*"}
+	patterns := parseTableList(testData)
+	actual := patterns.list()
+	expected := []string{"foo", "b*", "bar", "back_up"}
+	sort.Slice(actual, func(i, j int) bool {
+		return actual[i] > actual[j]
+	})
+	sort.Slice(expected, func(i, j int) bool {
+		return expected[i] > expected[j]
+	})
+	assert.Equal(t, strings.Join(expected, ","), strings.Join(actual, ","))
+
+	matchTestData := map[string]bool{
+		"foo":     true,
+		"bar":     true,
+		"back_up": true,
+		"bit":     true,
+		"ab":      false,
+		"b":       true,
+	}
+	for v, expected := range matchTestData {
+		actual := patterns.Match(v)
+		assert.Equal(t, expected, actual)
+	}
+}