Pārlūkot izejas kodu

optimize code (#579)

* optimize code

* optimize returns & unit test
anqiansong 4 gadi atpakaļ
vecāks
revīzija
888551627c

+ 19 - 11
tools/goctl/api/javagen/gencomponents.go

@@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
 }
 
 func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
-	defineStruct, ok := ty.(spec.DefineStruct)
-	if !ok {
-		return errors.New("unsupported type %s" + ty.Name())
-	}
-
-	for _, item := range c.requestTypes {
-		if item.Name() == defineStruct.Name() {
-			if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
-				return nil
-			}
-		}
+	defineStruct, done, err := c.checkStruct(ty)
+	if done {
+		return err
 	}
 
 	modelFile := util.Title(ty.Name()) + ".java"
@@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
 	return err
 }
 
+func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
+	defineStruct, ok := ty.(spec.DefineStruct)
+	if !ok {
+		return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
+	}
+
+	for _, item := range c.requestTypes {
+		if item.Name() == defineStruct.Name() {
+			if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
+				return spec.DefineStruct{}, true, nil
+			}
+		}
+	}
+	return defineStruct, false, nil
+}
+
 func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
 	var builder strings.Builder
 	if err := c.writeType(&builder, defineStruct); err != nil {

+ 20 - 11
tools/goctl/api/javagen/util.go

@@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
 			return "", err
 		}
 
-		switch valueType {
-		case "int":
-			return "Integer[]", nil
-		case "long":
-			return "Long[]", nil
-		case "float":
-			return "Float[]", nil
-		case "double":
-			return "Double[]", nil
-		case "boolean":
-			return "Boolean[]", nil
+		s := getBaseType(valueType)
+		if len(s) == 0 {
+			return s, errors.New("unsupported primitive type " + tp.Name())
 		}
 
 		return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
@@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
 	return "", errors.New("unsupported primitive type " + tp.Name())
 }
 
+func getBaseType(valueType string) string {
+	switch valueType {
+	case "int":
+		return "Integer[]"
+	case "long":
+		return "Long[]"
+	case "float":
+		return "Float[]"
+	case "double":
+		return "Double[]"
+	case "boolean":
+		return "Boolean[]"
+	default:
+		return ""
+	}
+}
+
 func primitiveType(tp string) (string, bool) {
 	switch tp {
 	case "string":

+ 6 - 1
tools/goctl/model/sql/command/command_test.go

@@ -6,6 +6,8 @@ import (
 	"path/filepath"
 	"testing"
 
+	"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
+
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/tools/goctl/config"
 	"github.com/tal-tech/go-zero/tools/goctl/util"
@@ -19,7 +21,10 @@ var (
 )
 
 func TestFromDDl(t *testing.T) {
-	err := fromDDl("./user.sql", t.TempDir(), cfg, true, false)
+	err := gen.Clean()
+	assert.Nil(t, err)
+
+	err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
 	assert.Equal(t, errNotMatched, err)
 
 	// case dir is not exists

+ 23 - 21
tools/goctl/model/sql/gen/findonebyfield.go

@@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
 	var list []string
 	camelTableName := table.Name.ToCamel()
 	for _, key := range table.UniqueCacheKey {
-		var inJoin, paramJoin, argJoin Join
-		for _, f := range key.Fields {
-			param := stringx.From(f.Name.ToCamel()).Untitle()
-			inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
-			paramJoin = append(paramJoin, param)
-			argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
-		}
-		var in string
-		if len(inJoin) > 0 {
-			in = inJoin.With(", ").Source()
-		}
-
-		var paramJoinString string
-		if len(paramJoin) > 0 {
-			paramJoinString = paramJoin.With(",").Source()
-		}
-
-		var originalFieldString string
-		if len(argJoin) > 0 {
-			originalFieldString = argJoin.With(" and ").Source()
-		}
+		in, paramJoinString, originalFieldString := convertJoin(key)
 
 		output, err := t.Execute(map[string]interface{}{
 			"upperStartCamelObject":     camelTableName,
@@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
 		findOneInterfaceMethod: strings.Join(listMethod, util.NL),
 	}, nil
 }
+
+func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
+	var inJoin, paramJoin, argJoin Join
+	for _, f := range key.Fields {
+		param := stringx.From(f.Name.ToCamel()).Untitle()
+		inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
+		paramJoin = append(paramJoin, param)
+		argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
+	}
+	if len(inJoin) > 0 {
+		in = inJoin.With(", ").Source()
+	}
+
+	if len(paramJoin) > 0 {
+		paramJoinString = paramJoin.With(",").Source()
+	}
+
+	if len(argJoin) > 0 {
+		originalFieldString = argJoin.With(" and ").Source()
+	}
+	return in, paramJoinString, originalFieldString
+}

+ 40 - 29
tools/goctl/model/sql/parser/parser.go

@@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
 		}
 	}
 
+	checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
+	return &Table{
+		Name:        stringx.From(tableName),
+		PrimaryKey:  primaryKey,
+		UniqueIndex: uniqueIndex,
+		NormalIndex: normalIndex,
+		Fields:      fields,
+	}, nil
+}
+
+func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
 	log := console.NewColorConsole()
 	uniqueSet := collection.NewSet()
 	for k, i := range uniqueIndex {
@@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {
 
 		normalIndexSet.Add(joinRet)
 	}
-
-	return &Table{
-		Name:        stringx.From(tableName),
-		PrimaryKey:  primaryKey,
-		UniqueIndex: uniqueIndex,
-		NormalIndex: normalIndex,
-		Fields:      fields,
-	}, nil
 }
 
 func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
@@ -289,27 +292,9 @@ func ConvertDataType(table *model.Table) (*Table, error) {
 		AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
 	}
 
-	fieldM := make(map[string]*Field)
-	for _, each := range table.Columns {
-		isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
-		dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
-		if err != nil {
-			return nil, err
-		}
-		columnSeqInIndex := 0
-		if each.Index != nil {
-			columnSeqInIndex = each.Index.SeqInIndex
-		}
-
-		field := &Field{
-			Name:            stringx.From(each.Name),
-			DataBaseType:    each.DataType,
-			DataType:        dt,
-			Comment:         each.Comment,
-			SeqInIndex:      columnSeqInIndex,
-			OrdinalPosition: each.OrdinalPosition,
-		}
-		fieldM[each.Name] = field
+	fieldM, err := getTableFields(table)
+	if err != nil {
+		return nil, err
 	}
 
 	for _, each := range fieldM {
@@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {
 
 	return &reply, nil
 }
+
+func getTableFields(table *model.Table) (map[string]*Field, error) {
+	fieldM := make(map[string]*Field)
+	for _, each := range table.Columns {
+		isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
+		dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
+		if err != nil {
+			return nil, err
+		}
+		columnSeqInIndex := 0
+		if each.Index != nil {
+			columnSeqInIndex = each.Index.SeqInIndex
+		}
+
+		field := &Field{
+			Name:            stringx.From(each.Name),
+			DataBaseType:    each.DataType,
+			DataType:        dt,
+			Comment:         each.Comment,
+			SeqInIndex:      columnSeqInIndex,
+			OrdinalPosition: each.OrdinalPosition,
+		}
+		fieldM[each.Name] = field
+	}
+	return fieldM, nil
+}