Procházet zdrojové kódy

(goctl:) fix circle import in case new parser (#3750)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
kesonan před 1 rokem
rodič
revize
5e63002cf8

+ 31 - 0
tools/goctl/pkg/parser/api/importstack/importstack.go

@@ -0,0 +1,31 @@
+package importstack
+
+import "errors"
+
+// ErrImportCycleNotAllowed defines an error for circular importing
+var ErrImportCycleNotAllowed = errors.New("import cycle not allowed")
+
+// ImportStack a stack of import paths
+type ImportStack []string
+
+func New() *ImportStack {
+	return &ImportStack{}
+}
+
+func (s *ImportStack) Push(p string) error {
+	for _, x := range *s {
+		if x == p {
+			return ErrImportCycleNotAllowed
+		}
+	}
+	*s = append(*s, p)
+	return nil
+}
+
+func (s *ImportStack) Pop() {
+	*s = (*s)[0 : len(*s)-1]
+}
+
+func (s *ImportStack) List() []string {
+	return *s
+}

+ 10 - 3
tools/goctl/pkg/parser/api/parser/analyzer.go

@@ -5,8 +5,10 @@ import (
 	"sort"
 	"strings"
 
+	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/tools/goctl/api/spec"
 	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast"
+	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/importstack"
 	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder"
 	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token"
 )
@@ -390,9 +392,14 @@ func Parse(filename string, src interface{}) (*spec.ApiSpec, error) {
 		return nil, err
 	}
 
-	var importManager = make(map[string]placeholder.Type)
-	importManager[ast.Filename] = placeholder.PlaceHolder
-	api, err := convert2API(ast, importManager)
+	is := importstack.New()
+	err := is.Push(ast.Filename)
+	if err != nil {
+		return nil, err
+	}
+
+	importSet := map[string]lang.PlaceholderType{}
+	api, err := convert2API(ast, importSet, is)
 	if err != nil {
 		return nil, err
 	}

+ 19 - 13
tools/goctl/pkg/parser/api/parser/api.go

@@ -5,7 +5,9 @@ import (
 	"path/filepath"
 	"strings"
 
+	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast"
+	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/importstack"
 	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder"
 	"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token"
 )
@@ -18,15 +20,17 @@ type API struct {
 	importStmt    []ast.ImportStmt // ImportStmt block does not participate in code generation.
 	TypeStmt      []ast.TypeStmt
 	ServiceStmts  []*ast.ServiceStmt
-	importManager map[string]placeholder.Type
+	importManager *importstack.ImportStack
+	importSet     map[string]lang.PlaceholderType
 }
 
-func convert2API(a *ast.AST, importManager map[string]placeholder.Type) (*API, error) {
+func convert2API(a *ast.AST, importSet map[string]lang.PlaceholderType, is *importstack.ImportStack) (*API, error) {
 	var api = new(API)
-	api.importManager = make(map[string]placeholder.Type)
+	api.importManager = is
+	api.importSet = make(map[string]lang.PlaceholderType)
 	api.Filename = a.Filename
-	for k, v := range importManager {
-		api.importManager[k] = v
+	for k, v := range importSet {
+		api.importSet[k] = v
 	}
 	one := a.Stmts[0]
 	syntax, ok := one.(*ast.SyntaxStmt)
@@ -230,9 +234,6 @@ func (api *API) getAtServerValue(atServer *ast.AtServerStmt, key string) string
 }
 
 func (api *API) mergeAPI(in *API) error {
-	for k, v := range in.importManager {
-		api.importManager[k] = v
-	}
 	if api.Syntax.Value.Format() != in.Syntax.Value.Format() {
 		return ast.SyntaxError(in.Syntax.Value.Pos(),
 			"multiple syntax value expression, expected <%s>, got <%s>",
@@ -269,11 +270,15 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) {
 			impPath = filepath.Join(dir, impPath)
 		}
 		// import cycle check
-		if _, ok := api.importManager[impPath]; ok {
-			return nil, ast.SyntaxError(tok.Position, "import circle not allowed")
-		} else {
-			api.importManager[impPath] = placeholder.PlaceHolder
+		if err := api.importManager.Push(impPath); err != nil {
+			return nil, ast.SyntaxError(tok.Position, err.Error())
+		}
+
+		if _, ok := api.importSet[impPath]; ok {
+			api.importManager.Pop()
+			continue
 		}
+		api.importSet[impPath] = lang.Placeholder
 
 		p := New(impPath, "")
 		ast := p.Parse()
@@ -281,7 +286,7 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) {
 			return nil, err
 		}
 
-		nestedApi, err := convert2API(ast, api.importManager)
+		nestedApi, err := convert2API(ast, api.importSet, api.importManager)
 		if err != nil {
 			return nil, err
 		}
@@ -290,6 +295,7 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) {
 			return nil, err
 		}
 
+		api.importManager.Pop()
 		list = append(list, nestedApi)
 
 		if err != nil {

+ 12 - 8
tools/goctl/pkg/parser/api/scanner/scanner.go

@@ -26,7 +26,6 @@ const (
 	stringOpen
 	stringClose
 	// string mode end
-
 )
 
 var missingInput = errors.New("missing input")
@@ -268,6 +267,7 @@ func (s *Scanner) scanNanosecond(bgPos int) token.Token {
 		return s.illegalToken()
 	}
 	s.readRune()
+
 	return token.Token{
 		Type:     token.DURATION,
 		Text:     string(s.data[bgPos:s.position]),
@@ -485,6 +485,7 @@ func (s *Scanner) scanLineComment() token.Token {
 	for s.ch != '\n' && s.ch != 0 {
 		s.readRune()
 	}
+
 	return token.Token{
 		Type:     token.COMMENT,
 		Text:     string(s.data[position:s.position]),
@@ -546,6 +547,7 @@ func (s *Scanner) assertExpected(actual token.Type, expected ...token.Type) erro
 		strings.Join(expects, " | "),
 		actual.String(),
 	))
+
 	return errors.New(text)
 }
 
@@ -560,6 +562,7 @@ func (s *Scanner) assertExpectedString(actual string, expected ...string) error
 		strings.Join(expects, " | "),
 		actual,
 	))
+
 	return errors.New(text)
 }
 
@@ -647,21 +650,22 @@ func NewScanner(filename string, src interface{}) (*Scanner, error) {
 }
 
 func readData(filename string, src interface{}) ([]byte, error) {
-	data, err := os.ReadFile(filename)
-	if err == nil {
+	if strings.HasSuffix(filename, ".api") {
+		data, err := os.ReadFile(filename)
+		if err != nil {
+			return nil, err
+		}
 		return data, nil
 	}
 
 	switch v := src.(type) {
 	case []byte:
-		data = append(data, v...)
+		return v, nil
 	case *bytes.Buffer:
-		data = v.Bytes()
+		return v.Bytes(), nil
 	case string:
-		data = []byte(v)
+		return []byte(v), nil
 	default:
 		return nil, fmt.Errorf("unsupported type: %T", src)
 	}
-
-	return data, nil
 }