Bläddra i källkod

feat(goctl): supports api multi-level importing (#1747)

* feat(goctl): supports api  multi-level importing

Resolves: #1744

* fix(goctl): import-cycle, etc.

import-cycle will not be allowed
e.g., a.api -> b.api -> a.api
regular multiple-import will be allowed
e.g., a.api -> b.api -> c.api
                   -> c.api

* refactor(goctl): adds comments to exported var

* fix(goctl): typo in a comment
Fyn 3 år sedan
förälder
incheckning
6d9dfc08f9

+ 81 - 45
tools/goctl/api/parser/g4/ast/apiparser.go

@@ -15,12 +15,18 @@ import (
 type (
 	// Parser provides api parsing capabilities
 	Parser struct {
-		linePrefix string
-		debug      bool
-		log        console.Console
 		antlr.DefaultErrorListener
+		linePrefix               string
+		debug                    bool
+		log                      console.Console
 		src                      string
 		skipCheckTypeDeclaration bool
+		handlerMap               map[string]PlaceHolder
+		routeMap                 map[string]PlaceHolder
+		typeMap                  map[string]PlaceHolder
+		fileMap                  map[string]PlaceHolder
+		importStatck             importStack
+		syntax                   *SyntaxExpr
 	}
 
 	// ParserOption defines an function with argument Parser
@@ -35,6 +41,10 @@ func NewParser(options ...ParserOption) *Parser {
 	for _, opt := range options {
 		opt(p)
 	}
+	p.handlerMap = make(map[string]PlaceHolder)
+	p.routeMap = make(map[string]PlaceHolder)
+	p.typeMap = make(map[string]PlaceHolder)
+	p.fileMap = make(map[string]PlaceHolder)
 
 	return p
 }
@@ -84,6 +94,7 @@ func (p *Parser) Parse(filename string) (*Api, error) {
 		return nil, err
 	}
 
+	p.importStatck.push(p.src)
 	return p.parse(filename, data)
 }
 
@@ -100,6 +111,7 @@ func (p *Parser) ParseContent(content string, filename ...string) (*Api, error)
 		p.src = abs
 	}
 
+	p.importStatck.push(p.src)
 	return p.parse(f, content)
 }
 
@@ -113,12 +125,43 @@ func (p *Parser) parse(filename, content string) (*Api, error) {
 
 	var apiAstList []*Api
 	apiAstList = append(apiAstList, root)
-	for _, imp := range root.Import {
+	p.storeVerificationInfo(root)
+	p.syntax = root.Syntax
+	impApiAstList, err := p.invokeImportedApi(root.Import)
+	if err != nil {
+		return nil, err
+	}
+	apiAstList = append(apiAstList, impApiAstList...)
+
+	if !p.skipCheckTypeDeclaration {
+		err = p.checkTypeDeclaration(apiAstList)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	allApi := p.memberFill(apiAstList)
+	return allApi, nil
+}
+
+func (p *Parser) invokeImportedApi(imports []*ImportExpr) ([]*Api, error) {
+	var apiAstList []*Api
+	for _, imp := range imports {
 		dir := filepath.Dir(p.src)
 		impPath := strings.ReplaceAll(imp.Value.Text(), "\"", "")
 		if !filepath.IsAbs(impPath) {
 			impPath = filepath.Join(dir, impPath)
 		}
+		// import cycle check
+		if err := p.importStatck.push(impPath); err != nil {
+			return nil, err
+		}
+		// ignore already imported file
+		if p.alreadyImported(impPath) {
+			continue
+		}
+		p.fileMap[impPath] = PlaceHolder{}
+
 		data, err := p.readContent(impPath)
 		if err != nil {
 			return nil, err
@@ -129,23 +172,26 @@ func (p *Parser) parse(filename, content string) (*Api, error) {
 			return nil, err
 		}
 
-		err = p.valid(root, nestedApi)
+		err = p.valid(nestedApi)
 		if err != nil {
 			return nil, err
 		}
-
+		p.storeVerificationInfo(nestedApi)
 		apiAstList = append(apiAstList, nestedApi)
-	}
+		list, err := p.invokeImportedApi(nestedApi.Import)
+		p.importStatck.pop()
+		apiAstList = append(apiAstList, list...)
 
-	if !p.skipCheckTypeDeclaration {
-		err = p.checkTypeDeclaration(apiAstList)
 		if err != nil {
 			return nil, err
 		}
 	}
+	return apiAstList, nil
+}
 
-	allApi := p.memberFill(apiAstList)
-	return allApi, nil
+func (p *Parser) alreadyImported(filename string) bool {
+	_, ok := p.fileMap[filename]
+	return ok
 }
 
 func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
@@ -184,58 +230,48 @@ func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
 	return
 }
 
-func (p *Parser) valid(mainApi, nestedApi *Api) error {
-	err := p.nestedApiCheck(mainApi, nestedApi)
-	if err != nil {
-		return err
-	}
-
-	mainHandlerMap := make(map[string]PlaceHolder)
-	mainRouteMap := make(map[string]PlaceHolder)
-	mainTypeMap := make(map[string]PlaceHolder)
-
-	routeMap := func(list []*ServiceRoute) (map[string]PlaceHolder, map[string]PlaceHolder) {
-		handlerMap := make(map[string]PlaceHolder)
-		routeMap := make(map[string]PlaceHolder)
-
+// storeVerificationInfo stores information for verification
+func (p *Parser) storeVerificationInfo(api *Api) {
+	routeMap := func(list []*ServiceRoute) {
 		for _, g := range list {
 			handler := g.GetHandler()
 			if handler.IsNotNil() {
 				handlerName := handler.Text()
-				handlerMap[handlerName] = Holder
+				p.handlerMap[handlerName] = Holder
 				route := fmt.Sprintf("%s://%s", g.Route.Method.Text(), g.Route.Path.Text())
-				routeMap[route] = Holder
+				p.routeMap[route] = Holder
 			}
 		}
+	}
 
-		return handlerMap, routeMap
+	for _, each := range api.Service {
+		routeMap(each.ServiceApi.ServiceRoute)
 	}
 
-	for _, each := range mainApi.Service {
-		h, r := routeMap(each.ServiceApi.ServiceRoute)
+	for _, each := range api.Type {
+		p.typeMap[each.NameExpr().Text()] = Holder
+	}
+}
 
-		for k, v := range h {
-			mainHandlerMap[k] = v
-		}
+func (p *Parser) valid(nestedApi *Api) error {
 
-		for k, v := range r {
-			mainRouteMap[k] = v
+	if p.syntax != nil && nestedApi.Syntax != nil {
+		if p.syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
+			syntaxToken := nestedApi.Syntax.Syntax
+			return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
+				nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), p.syntax.Version.Text(), nestedApi.Syntax.Version.Text())
 		}
 	}
 
-	for _, each := range mainApi.Type {
-		mainTypeMap[each.NameExpr().Text()] = Holder
-	}
-
 	// duplicate route check
-	err = p.duplicateRouteCheck(nestedApi, mainHandlerMap, mainRouteMap)
+	err := p.duplicateRouteCheck(nestedApi)
 	if err != nil {
 		return err
 	}
 
 	// duplicate type check
 	for _, each := range nestedApi.Type {
-		if _, ok := mainTypeMap[each.NameExpr().Text()]; ok {
+		if _, ok := p.typeMap[each.NameExpr().Text()]; ok {
 			return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
 				nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text())
 		}
@@ -244,7 +280,7 @@ func (p *Parser) valid(mainApi, nestedApi *Api) error {
 	return nil
 }
 
-func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap, mainRouteMap map[string]PlaceHolder) error {
+func (p *Parser) duplicateRouteCheck(nestedApi *Api) error {
 	for _, each := range nestedApi.Service {
 		var prefix, group string
 		if each.AtServer != nil {
@@ -267,13 +303,13 @@ func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap, mainRouteMa
 			if len(group) > 0 {
 				handlerKey = fmt.Sprintf("%s/%s", group, handler.Text())
 			}
-			if _, ok := mainHandlerMap[handlerKey]; ok {
+			if _, ok := p.handlerMap[handlerKey]; ok {
 				return fmt.Errorf("%s line %d:%d duplicate handler '%s'",
 					nestedApi.LinePrefix, handler.Line(), handler.Column(), handlerKey)
 			}
 
-			p := fmt.Sprintf("%s://%s", r.Route.Method.Text(), path.Join(prefix, r.Route.Path.Text()))
-			if _, ok := mainRouteMap[p]; ok {
+			routeKey := fmt.Sprintf("%s://%s", r.Route.Method.Text(), path.Join(prefix, r.Route.Path.Text()))
+			if _, ok := p.routeMap[routeKey]; ok {
 				return fmt.Errorf("%s line %d:%d duplicate route '%s'",
 					nestedApi.LinePrefix, r.Route.Method.Line(), r.Route.Method.Column(), r.Route.Method.Text()+" "+r.Route.Path.Text())
 			}

+ 99 - 0
tools/goctl/api/parser/g4/ast/apiparser_test.go

@@ -0,0 +1,99 @@
+package ast
+
+import (
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
+)
+
+func Test_ImportCycle(t *testing.T) {
+	const (
+		mainFilename = "main.api"
+		subAFilename = "a.api"
+		subBFilename = "b.api"
+		mainSrc      = `import "./a.api"`
+		subASrc      = `import "./b.api"`
+		subBSrc      = `import "./a.api"`
+	)
+	var err error
+	dir := pathx.MustTempDir()
+	defer os.RemoveAll(dir)
+
+	mainPath := filepath.Join(dir, mainFilename)
+	err = ioutil.WriteFile(mainPath, []byte(mainSrc), 0777)
+	require.NoError(t, err)
+	subAPath := filepath.Join(dir, subAFilename)
+	err = ioutil.WriteFile(subAPath, []byte(subASrc), 0777)
+	require.NoError(t, err)
+	subBPath := filepath.Join(dir, subBFilename)
+	err = ioutil.WriteFile(subBPath, []byte(subBSrc), 0777)
+	require.NoError(t, err)
+
+	_, err = NewParser().Parse(mainPath)
+	assert.ErrorIs(t, err, ErrImportCycleNotAllowed)
+}
+
+func Test_MultiImportedShouldAllowed(t *testing.T) {
+	const (
+		mainFilename = "main.api"
+		subAFilename = "a.api"
+		subBFilename = "b.api"
+		mainSrc      = "import \"./b.api\"\n" +
+			"import \"./a.api\"\n" +
+			"type Main { b B `json:\"b\"`}"
+		subASrc = "import \"./b.api\"\n" +
+			"type A { b B `json:\"b\"`}\n"
+		subBSrc = `type B{}`
+	)
+	var err error
+	dir := pathx.MustTempDir()
+	defer os.RemoveAll(dir)
+
+	mainPath := filepath.Join(dir, mainFilename)
+	err = ioutil.WriteFile(mainPath, []byte(mainSrc), 0777)
+	require.NoError(t, err)
+	subAPath := filepath.Join(dir, subAFilename)
+	err = ioutil.WriteFile(subAPath, []byte(subASrc), 0777)
+	require.NoError(t, err)
+	subBPath := filepath.Join(dir, subBFilename)
+	err = ioutil.WriteFile(subBPath, []byte(subBSrc), 0777)
+	require.NoError(t, err)
+
+	_, err = NewParser().Parse(mainPath)
+	assert.NoError(t, err)
+}
+
+func Test_RedundantDeclarationShouldNotBeAllowed(t *testing.T) {
+	const (
+		mainFilename = "main.api"
+		subAFilename = "a.api"
+		subBFilename = "b.api"
+		mainSrc      = "import \"./a.api\"\n" +
+			"import \"./b.api\"\n"
+		subASrc = `import "./b.api"
+							 type A{}`
+		subBSrc = `type A{}`
+	)
+	var err error
+	dir := pathx.MustTempDir()
+	defer os.RemoveAll(dir)
+
+	mainPath := filepath.Join(dir, mainFilename)
+	err = ioutil.WriteFile(mainPath, []byte(mainSrc), 0777)
+	require.NoError(t, err)
+	subAPath := filepath.Join(dir, subAFilename)
+	err = ioutil.WriteFile(subAPath, []byte(subASrc), 0777)
+	require.NoError(t, err)
+	subBPath := filepath.Join(dir, subBFilename)
+	err = ioutil.WriteFile(subBPath, []byte(subBSrc), 0777)
+	require.NoError(t, err)
+
+	_, err = NewParser().Parse(mainPath)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "duplicate type declaration")
+}

+ 23 - 0
tools/goctl/api/parser/g4/ast/importstack.go

@@ -0,0 +1,23 @@
+package ast
+
+import "errors"
+
+// ErrImportCycleNotAllowed defines an error for circular importing
+var ErrImportCycleNotAllowed = errors.New("import cycle not allowed")
+
+// importStack a stack of import paths
+type importStack []string
+
+func (s *importStack) push(p string) error {
+	for _, x := range *s {
+		if x == p {
+			return ErrImportCycleNotAllowed
+		}
+	}
+	*s = append(*s, p)
+	return nil
+}
+
+func (s *importStack) pop() {
+	*s = (*s)[0 : len(*s)-1]
+}