Bläddra i källkod

refactor: guard timeout on API files (#1726)

Kevin Wan 3 år sedan
förälder
incheckning
2b9fc26c38
5 ändrade filer med 44 tillägg och 31 borttagningar
  1. 2 1
      core/stores/sqlx/utils.go
  2. 8 9
      rest/engine.go
  3. 7 7
      rest/server.go
  4. 7 0
      rest/server_test.go
  5. 20 14
      tools/goctl/api/gogen/genroutes.go

+ 2 - 1
core/stores/sqlx/utils.go

@@ -75,6 +75,7 @@ func format(query string, args ...interface{}) (string, error) {
 					break
 				}
 			}
+
 			if j > i+1 {
 				index, err := strconv.Atoi(query[i+1 : j])
 				if err != nil {
@@ -85,7 +86,7 @@ func format(query string, args ...interface{}) (string, error) {
 				if index > argIndex {
 					argIndex = index
 				}
-				
+
 				index--
 				if index < 0 || numArgs <= index {
 					return "", fmt.Errorf("error: wrong index %d in sql", index)

+ 8 - 9
rest/engine.go

@@ -119,16 +119,7 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
 	return nil
 }
 
-func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
-	if timeout > 0 {
-		return timeout
-	}
-
-	return time.Duration(ng.conf.Timeout) * time.Millisecond
-}
-
 func (ng *engine) checkedMaxBytes(bytes int64) int64 {
-
 	if bytes > 0 {
 		return bytes
 	}
@@ -136,6 +127,14 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
 	return ng.conf.MaxBytes
 }
 
+func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
+	if timeout > 0 {
+		return timeout
+	}
+
+	return time.Duration(ng.conf.Timeout) * time.Millisecond
+}
+
 func (ng *engine) createMetrics() *stat.Metrics {
 	var metrics *stat.Metrics
 

+ 7 - 7
rest/server.go

@@ -137,6 +137,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
 	}
 }
 
+// WithMaxBytes returns a RouteOption to set maxBytes with the given value.
+func WithMaxBytes(maxBytes int64) RouteOption {
+	return func(r *featuredRoutes) {
+		r.maxBytes = maxBytes
+	}
+}
+
 // WithMiddlewares adds given middlewares to given routes.
 func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
 	for i := len(ms) - 1; i >= 0; i-- {
@@ -223,13 +230,6 @@ func WithTimeout(timeout time.Duration) RouteOption {
 	}
 }
 
-// WithMaxBytes returns a RouteOption to set maxBytes with given value.
-func WithMaxBytes(maxBytes int64) RouteOption {
-	return func(r *featuredRoutes) {
-		r.maxBytes = maxBytes
-	}
-}
-
 // WithTLSConfig returns a RunOption that with given tls config.
 func WithTLSConfig(cfg *tls.Config) RunOption {
 	return func(svr *Server) {

+ 7 - 0
rest/server_test.go

@@ -95,6 +95,13 @@ Port: 54321
 	}
 }
 
+func TestWithMaxBytes(t *testing.T) {
+	const maxBytes = 1000
+	var fr featuredRoutes
+	WithMaxBytes(maxBytes)(&fr)
+	assert.Equal(t, int64(maxBytes), fr.maxBytes)
+}
+
 func TestWithMiddleware(t *testing.T) {
 	m := make(map[string]string)
 	rt := router.NewRouter()

+ 20 - 14
tools/goctl/api/gogen/genroutes.go

@@ -24,7 +24,8 @@ const (
 package handler
 
 import (
-	"net/http"
+	"net/http"{{if .hasTimeout}}
+	"time"{{end}}
 
 	{{.importPackages}}
 )
@@ -38,6 +39,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
 		{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
 	)
 `
+	timeoutThreshold = time.Millisecond
 )
 
 var mapping = map[string]string{
@@ -59,7 +61,6 @@ type (
 		signatureEnabled bool
 		authName         string
 		timeout          string
-		timeoutEnable    bool
 		middlewares      []string
 		prefix           string
 		jwtTrans         string
@@ -83,6 +84,7 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
 		return err
 	}
 
+	var hasTimeout bool
 	gt := template.Must(template.New("groupTemplate").Parse(templateText))
 	for _, g := range groups {
 		var gbuilder strings.Builder
@@ -114,12 +116,19 @@ rest.WithPrefix("%s"),`, g.prefix)
 		}
 
 		var timeout string
-		if g.timeoutEnable {
+		if len(g.timeout) > 0 {
 			duration, err := time.ParseDuration(g.timeout)
 			if err != nil {
-				panic(err)
+				return err
 			}
-			timeout = fmt.Sprintf("rest.WithTimeout(%d),", duration)
+
+			// why we check this, maybe some users set value 1, it's 1ns, not 1s.
+			if duration < timeoutThreshold {
+				return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
+			}
+
+			timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
+			hasTimeout = true
 		}
 
 		var routes string
@@ -152,8 +161,8 @@ rest.WithPrefix("%s"),`, g.prefix)
 	if err != nil {
 		return err
 	}
-	routeFilename = routeFilename + ".go"
 
+	routeFilename = routeFilename + ".go"
 	filename := path.Join(dir, handlerDir, routeFilename)
 	os.Remove(filename)
 
@@ -165,7 +174,8 @@ rest.WithPrefix("%s"),`, g.prefix)
 		category:        category,
 		templateFile:    routesTemplateFile,
 		builtinTemplate: routesTemplate,
-		data: map[string]string{
+		data: map[string]interface{}{
+			"hasTimeout":      hasTimeout,
 			"importPackages":  genRouteImports(rootPkg, api),
 			"routesAdditions": strings.TrimSpace(builder.String()),
 		},
@@ -184,7 +194,8 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
 					continue
 				}
 			}
-			importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), pathx.JoinPackages(parentPkg, handlerDir, folder)))
+			importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder),
+				pathx.JoinPackages(parentPkg, handlerDir, folder)))
 		}
 	}
 	imports := importSet.KeysStr()
@@ -218,12 +229,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
 			})
 		}
 
-		timeout := g.GetAnnotation("timeout")
-
-		if len(timeout) > 0 {
-			groupedRoutes.timeoutEnable = true
-			groupedRoutes.timeout = timeout
-		}
+		groupedRoutes.timeout = g.GetAnnotation("timeout")
 
 		jwt := g.GetAnnotation("jwt")
 		if len(jwt) > 0 {