浏览代码

fixes #987 (#1283)

* fixes #987

* chore: fix test failure

* chore: add comments
Kevin Wan 3 年之前
父节点
当前提交
543d590710

+ 43 - 7
tools/goctl/api/gogen/genlogic.go

@@ -3,8 +3,10 @@ package gogen
 import (
 	"fmt"
 	"path"
+	"strconv"
 	"strings"
 
+	"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
 	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
 	"github.com/tal-tech/go-zero/tools/goctl/config"
 	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
@@ -64,12 +66,8 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group,
 	var requestString string
 	if len(route.ResponseTypeName()) > 0 {
 		resp := responseGoTypeName(route, typesPacket)
-		responseString = "(" + resp + ", error)"
-		if strings.HasPrefix(resp, "*") {
-			returnString = fmt.Sprintf("return &%s{}, nil", strings.TrimPrefix(resp, "*"))
-		} else {
-			returnString = fmt.Sprintf("return %s{}, nil", resp)
-		}
+		responseString = "(resp " + resp + ", err error)"
+		returnString = "return"
 	} else {
 		responseString = "error"
 		returnString = "return nil"
@@ -116,9 +114,47 @@ func genLogicImports(route spec.Route, parentPkg string) string {
 	var imports []string
 	imports = append(imports, `"context"`+"\n")
 	imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir)))
-	if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 {
+	if shallImportTypesPackage(route) {
 		imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
 	}
 	imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
 	return strings.Join(imports, "\n\t")
 }
+
+func onlyPrimitiveTypes(val string) bool {
+	fields := strings.FieldsFunc(val, func(r rune) bool {
+		return r == '[' || r == ']' || r == ' '
+	})
+
+	for _, field := range fields {
+		if field == "map" {
+			continue
+		}
+		// ignore array dimension number, like [5]int
+		if _, err := strconv.Atoi(field); err == nil {
+			continue
+		}
+		if !api.IsBasicType(field) {
+			return false
+		}
+	}
+
+	return true
+}
+
+func shallImportTypesPackage(route spec.Route) bool {
+	if len(route.RequestTypeName()) > 0 {
+		return true
+	}
+
+	respTypeName := route.ResponseTypeName()
+	if len(respTypeName) == 0 {
+		return false
+	}
+
+	if onlyPrimitiveTypes(respTypeName) {
+		return false
+	}
+
+	return true
+}

+ 2 - 5
tools/goctl/api/parser/g4/ast/service.go

@@ -267,11 +267,8 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) interface{} {
 		}
 	case *Literal:
 		lit := dataType.Literal.Text()
-		if api.IsGolangKeyWord(dataType.Literal.Text()) {
-			v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", dataType.Literal.Text()))
-		}
-		if api.IsBasicType(lit) {
-			v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
+		if api.IsGolangKeyWord(lit) {
+			v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit))
 		}
 	default:
 		v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))

+ 1 - 1
tools/goctl/api/parser/g4/test/service_test.go

@@ -174,7 +174,7 @@ func TestRoute(t *testing.T) {
 		assert.Error(t, err)
 
 		_, err = parser.Accept(fn, ` post /foo/bar returns (int)`)
-		assert.Error(t, err)
+		assert.Nil(t, err)
 
 		_, err = parser.Accept(fn, ` post /foo/bar returns (*int)`)
 		assert.Error(t, err)