浏览代码

chore: add more tests (#3018)

Kevin Wan 2 年之前
父节点
当前提交
60a13f1e53
共有 8 个文件被更改,包括 107 次插入39 次删除
  1. 6 1
      core/logx/logs.go
  2. 22 0
      core/logx/logs_test.go
  3. 2 2
      core/logx/rotatelogger.go
  4. 7 1
      core/logx/vars.go
  5. 11 7
      rest/engine.go
  6. 29 17
      rest/engine_test.go
  7. 11 9
      rest/server.go
  8. 19 2
      rest/server_test.go

+ 6 - 1
core/logx/logs.go

@@ -197,7 +197,12 @@ func Must(err error) {
 	msg := err.Error()
 	log.Print(msg)
 	getWriter().Severe(msg)
-	os.Exit(1)
+
+	if ExitOnFatal.True() {
+		os.Exit(1)
+	} else {
+		panic(msg)
+	}
 }
 
 // MustSetup sets up logging with given config c. It exits on error.

+ 22 - 0
core/logx/logs_test.go

@@ -24,6 +24,10 @@ var (
 	_    Writer = (*mockWriter)(nil)
 )
 
+func init() {
+	ExitOnFatal.Set(false)
+}
+
 type mockWriter struct {
 	lock    sync.Mutex
 	builder strings.Builder
@@ -208,6 +212,12 @@ func TestFileLineConsoleMode(t *testing.T) {
 	assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1)))
 }
 
+func TestMust(t *testing.T) {
+	assert.Panics(t, func() {
+		Must(errors.New("foo"))
+	})
+}
+
 func TestStructedLogAlert(t *testing.T) {
 	w := new(mockWriter)
 	old := writer.Swap(w)
@@ -574,26 +584,38 @@ func TestSetup(t *testing.T) {
 		atomic.StoreUint32(&encoding, jsonEncodingType)
 	}()
 
+	setupOnce = sync.Once{}
+	MustSetup(LogConf{
+		ServiceName: "any",
+		Mode:        "console",
+		Encoding:    "json",
+		TimeFormat:  timeFormat,
+	})
+	setupOnce = sync.Once{}
 	MustSetup(LogConf{
 		ServiceName: "any",
 		Mode:        "console",
 		TimeFormat:  timeFormat,
 	})
+	setupOnce = sync.Once{}
 	MustSetup(LogConf{
 		ServiceName: "any",
 		Mode:        "file",
 		Path:        os.TempDir(),
 	})
+	setupOnce = sync.Once{}
 	MustSetup(LogConf{
 		ServiceName: "any",
 		Mode:        "volume",
 		Path:        os.TempDir(),
 	})
+	setupOnce = sync.Once{}
 	MustSetup(LogConf{
 		ServiceName: "any",
 		Mode:        "console",
 		TimeFormat:  timeFormat,
 	})
+	setupOnce = sync.Once{}
 	MustSetup(LogConf{
 		ServiceName: "any",
 		Mode:        "console",

+ 2 - 2
core/logx/rotatelogger.go

@@ -237,7 +237,7 @@ func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger,
 		rule:     rule,
 		compress: compress,
 	}
-	if err := l.init(); err != nil {
+	if err := l.initialize(); err != nil {
 		return nil, err
 	}
 
@@ -281,7 +281,7 @@ func (l *RotateLogger) getBackupFilename() string {
 	return l.backup
 }
 
-func (l *RotateLogger) init() error {
+func (l *RotateLogger) initialize() error {
 	l.backup = l.rule.BackupFileName()
 
 	if fileInfo, err := os.Stat(l.filename); err != nil {

+ 7 - 1
core/logx/vars.go

@@ -1,6 +1,10 @@
 package logx
 
-import "errors"
+import (
+	"errors"
+
+	"github.com/zeromicro/go-zero/core/syncx"
+)
 
 const (
 	// DebugLevel logs everything
@@ -61,6 +65,8 @@ var (
 	ErrLogPathNotSet = errors.New("log path must be set")
 	// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
 	ErrLogServiceNameNotSet = errors.New("log service name must be set")
+	// ExitOnFatal defines whether to exit on fatal errors, defined here to make it easier to test.
+	ExitOnFatal = syncx.ForAtomicBool(true)
 
 	truncatedField = Field(truncatedKey, true)
 )

+ 11 - 7
rest/engine.go

@@ -301,22 +301,26 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chai
 	}, nil
 }
 
-func (ng *engine) start(router httpx.Router, opts ...internal.StartOption) error {
+func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
 	if err := ng.bindRoutes(router); err != nil {
 		return err
 	}
 
-	opts = append(opts, ng.withTimeout())
+	// make sure user defined options overwrite default options
+	opts = append([]StartOption{ng.withTimeout()}, opts...)
 
 	if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
 		return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
 	}
 
-	opts = append(opts, func(svr *http.Server) {
-		if ng.tlsConfig != nil {
-			svr.TLSConfig = ng.tlsConfig
-		}
-	})
+	// make sure user defined options overwrite default options
+	opts = append([]StartOption{
+		func(svr *http.Server) {
+			if ng.tlsConfig != nil {
+				svr.TLSConfig = ng.tlsConfig
+			}
+		},
+	}, opts...)
 
 	return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
 		ng.conf.KeyFile, router, opts...)

+ 29 - 17
rest/engine_test.go

@@ -3,6 +3,7 @@ package rest
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
 	"net/http/httptest"
 	"sync/atomic"
@@ -17,18 +18,21 @@ import (
 func TestNewEngine(t *testing.T) {
 	yamls := []string{
 		`Name: foo
-Port: 54321
+Host: localhost
+Port: 0
 Middlewares:
   Log: false
 `,
 		`Name: foo
-Port: 54321
+Host: localhost
+Port: 0
 CpuThreshold: 500
 Middlewares:
   Log: false
 `,
 		`Name: foo
-Port: 54321
+Host: localhost
+Port: 0
 CpuThreshold: 500
 Verbose: true
 `,
@@ -150,22 +154,29 @@ Verbose: true
 	}
 
 	for _, yaml := range yamls {
+		yaml := yaml
 		for _, route := range routes {
-			var cnf RestConf
-			assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
-			ng := newEngine(cnf)
-			ng.addRoutes(route)
-			ng.use(func(next http.HandlerFunc) http.HandlerFunc {
-				return func(w http.ResponseWriter, r *http.Request) {
-					next.ServeHTTP(w, r)
+			route := route
+			t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
+				var cnf RestConf
+				assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
+				ng := newEngine(cnf)
+				ng.addRoutes(route)
+				ng.use(func(next http.HandlerFunc) http.HandlerFunc {
+					return func(w http.ResponseWriter, r *http.Request) {
+						next.ServeHTTP(w, r)
+					}
+				})
+
+				assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
+				}))
+
+				timeout := time.Second * 3
+				if route.timeout > timeout {
+					timeout = route.timeout
 				}
+				assert.Equal(t, timeout, ng.timeout)
 			})
-			assert.NotNil(t, ng.start(mockedRouter{}))
-			timeout := time.Second * 3
-			if route.timeout > timeout {
-				timeout = route.timeout
-			}
-			assert.Equal(t, timeout, ng.timeout)
 		}
 	}
 }
@@ -340,7 +351,8 @@ func TestEngine_withTimeout(t *testing.T) {
 	}
 }
 
-type mockedRouter struct{}
+type mockedRouter struct {
+}
 
 func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
 }

+ 11 - 9
rest/server.go

@@ -2,7 +2,6 @@ package rest
 
 import (
 	"crypto/tls"
-	"log"
 	"net/http"
 	"path"
 	"time"
@@ -21,7 +20,7 @@ type (
 	RunOption func(*Server)
 
 	// StartOption defines the method to customize http server.
-	StartOption func(svr *http.Server)
+	StartOption = internal.StartOption
 
 	// A Server is a http server.
 	Server struct {
@@ -36,7 +35,7 @@ type (
 func MustNewServer(c RestConf, opts ...RunOption) *Server {
 	server, err := NewServer(c, opts...)
 	if err != nil {
-		log.Fatal(err)
+		logx.Must(err)
 	}
 
 	return server
@@ -116,12 +115,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 // Start starts the Server.
 // Graceful shutdown is enabled by default.
 // Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
-func (s *Server) Start(opts ...StartOption) {
-	var startOption []internal.StartOption
-	for _, opt := range opts {
-		startOption = append(startOption, internal.StartOption(opt))
-	}
-	handleError(s.ngin.start(s.router, startOption...))
+func (s *Server) Start() {
+	handleError(s.ngin.start(s.router))
+}
+
+// StartWithOpts starts the Server.
+// Graceful shutdown is enabled by default.
+// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
+func (s *Server) StartWithOpts(opts ...StartOption) {
+	handleError(s.ngin.start(s.router, opts...))
 }
 
 // Stop stops the Server.

+ 19 - 2
rest/server_test.go

@@ -28,7 +28,8 @@ func TestNewServer(t *testing.T) {
 
 	const configYaml = `
 Name: foo
-Port: 54321
+Host: localhost
+Port: 0
 `
 	var cnf RestConf
 	assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
@@ -101,6 +102,23 @@ Port: 54321
 			svr.Start()
 			svr.Stop()
 		}()
+
+		func() {
+			defer func() {
+				p := recover()
+				switch v := p.(type) {
+				case error:
+					assert.Equal(t, "foo", v.Error())
+				default:
+					t.Fail()
+				}
+			}()
+
+			svr.StartWithOpts(func(svr *http.Server) {
+				svr.RegisterOnShutdown(func() {})
+			})
+			svr.Stop()
+		}()
 	}
 }
 
@@ -569,7 +587,6 @@ Port: 54321
 			Method: http.MethodGet,
 			Path:   "/user/:name",
 			Handler: func(writer http.ResponseWriter, request *http.Request) {
-
 				var userInfo struct {
 					Name string `path:"name"`
 				}