Browse Source

chore: simplify tests with logtest (#3184)

Kevin Wan 2 years ago
parent
commit
14caf5c799

+ 1 - 0
.codecov.yml

@@ -6,3 +6,4 @@ ignore:
   - "tools"
   - "**/mock"
   - "**/*_mock.go"
+  - "**/*test"

+ 19 - 108
core/logc/logs_test.go

@@ -1,7 +1,6 @@
 package logc
 
 import (
-	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -11,14 +10,11 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 func TestAddGlobalFields(t *testing.T) {
-	var buf bytes.Buffer
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
 
 	Info(context.Background(), "hello")
 	buf.Reset()
@@ -34,155 +30,90 @@ func TestAddGlobalFields(t *testing.T) {
 }
 
 func TestAlert(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	Alert(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), "foo"), buf.String())
 }
 
 func TestError(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Error(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestErrorf(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Errorf(context.Background(), "foo %s", "bar")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestErrorv(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Errorv(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestErrorw(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Errorw(context.Background(), "foo", Field("a", "b"))
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestInfo(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Info(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestInfof(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Infof(context.Background(), "foo %s", "bar")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestInfov(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Infov(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestInfow(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Infow(context.Background(), "foo", Field("a", "b"))
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestDebug(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Debug(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestDebugf(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Debugf(context.Background(), "foo %s", "bar")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestDebugv(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Debugv(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
 }
 
 func TestDebugw(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Debugw(context.Background(), "foo", Field("a", "b"))
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
@@ -204,48 +135,28 @@ func TestMisc(t *testing.T) {
 }
 
 func TestSlow(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Slow(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
 }
 
 func TestSlowf(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Slowf(context.Background(), "foo %s", "bar")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
 }
 
 func TestSlowv(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Slowv(context.Background(), "foo")
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
 }
 
 func TestSloww(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
-
+	buf := logtest.NewCollector(t)
 	file, line := getFileLine()
 	Sloww(context.Background(), "foo", Field("a", "b"))
 	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())

+ 78 - 0
core/logx/logtest/logtest.go

@@ -0,0 +1,78 @@
+package logtest
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"testing"
+
+	"github.com/zeromicro/go-zero/core/logx"
+)
+
+type Buffer struct {
+	buf *bytes.Buffer
+	t   *testing.T
+}
+
+func Discard(t *testing.T) {
+	prev := logx.Reset()
+	logx.SetWriter(logx.NewWriter(io.Discard))
+
+	t.Cleanup(func() {
+		logx.SetWriter(prev)
+	})
+}
+
+func NewCollector(t *testing.T) *Buffer {
+	var buf bytes.Buffer
+	writer := logx.NewWriter(&buf)
+	prev := logx.Reset()
+	logx.SetWriter(writer)
+
+	t.Cleanup(func() {
+		logx.SetWriter(prev)
+	})
+
+	return &Buffer{
+		buf: &buf,
+		t:   t,
+	}
+}
+
+func (b *Buffer) Bytes() []byte {
+	return b.buf.Bytes()
+}
+
+func (b *Buffer) Content() string {
+	var m map[string]interface{}
+	if err := json.Unmarshal(b.buf.Bytes(), &m); err != nil {
+		b.t.Error(err)
+		return ""
+	}
+
+	content, ok := m["content"]
+	if !ok {
+		return ""
+	}
+
+	switch val := content.(type) {
+	case string:
+		return val
+	default:
+		bs, err := json.Marshal(content)
+		if err != nil {
+			b.t.Error(err)
+			return ""
+		}
+
+		return string(bs)
+	}
+}
+
+func (b *Buffer) Reset() {
+	b.buf.Reset()
+}
+
+func (b *Buffer) String() string {
+	return b.buf.String()
+}

+ 22 - 0
core/logx/logtest/logtest_test.go

@@ -0,0 +1,22 @@
+package logtest
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/zeromicro/go-zero/core/logx"
+)
+
+func TestCollector(t *testing.T) {
+	const input = "hello"
+	c := NewCollector(t)
+	logx.Info(input)
+	assert.Equal(t, input, c.Content())
+	assert.Contains(t, c.String(), input)
+}
+
+func TestDiscard(t *testing.T) {
+	const input = "hello"
+	Discard(t)
+	logx.Info(input)
+}

+ 2 - 10
core/proc/goroutines_test.go

@@ -5,19 +5,11 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 func TestDumpGoroutines(t *testing.T) {
-	var buf strings.Builder
-	w := logx.NewWriter(&buf)
-	o := logx.Reset()
-	logx.SetWriter(w)
-	defer func() {
-		logx.Reset()
-		logx.SetWriter(o)
-	}()
-
+	buf := logtest.NewCollector(t)
 	dumpGoroutines()
 	assert.True(t, strings.Contains(buf.String(), ".dump"))
 }

+ 3 - 12
core/proc/profile_test.go

@@ -5,25 +5,16 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 func TestProfile(t *testing.T) {
-	var buf strings.Builder
-	w := logx.NewWriter(&buf)
-	o := logx.Reset()
-	logx.SetWriter(w)
-
-	defer func() {
-		logx.Reset()
-		logx.SetWriter(o)
-	}()
-
+	c := logtest.NewCollector(t)
 	profiler := StartProfile()
 	// start again should not work
 	assert.NotNil(t, StartProfile())
 	profiler.Stop()
 	// stop twice
 	profiler.Stop()
-	assert.True(t, strings.Contains(buf.String(), ".pprof"))
+	assert.True(t, strings.Contains(c.String(), ".pprof"))
 }

+ 3 - 8
core/stat/usage_test.go

@@ -1,12 +1,11 @@
 package stat
 
 import (
-	"bytes"
 	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 func TestBToMb(t *testing.T) {
@@ -41,15 +40,11 @@ func TestBToMb(t *testing.T) {
 }
 
 func TestPrintUsage(t *testing.T) {
-	var buf bytes.Buffer
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	c := logtest.NewCollector(t)
 
 	printUsage()
 
-	output := buf.String()
+	output := c.String()
 	assert.Contains(t, output, "CPU:")
 	assert.Contains(t, output, "MEMORY:")
 	assert.Contains(t, output, "Alloc=")

+ 2 - 11
core/stores/mon/collection_test.go

@@ -3,12 +3,11 @@ package mon
 import (
 	"context"
 	"errors"
-	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/breaker"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 	"github.com/zeromicro/go-zero/core/stringx"
 	"github.com/zeromicro/go-zero/core/timex"
 	"go.mongodb.org/mongo-driver/bson"
@@ -573,15 +572,7 @@ func TestDecoratedCollection_LogDuration(t *testing.T) {
 		brk:        breaker.NewBreaker(),
 	}
 
-	var buf strings.Builder
-	w := logx.NewWriter(&buf)
-	o := logx.Reset()
-	logx.SetWriter(w)
-
-	defer func() {
-		logx.Reset()
-		logx.SetWriter(o)
-	}()
+	buf := logtest.NewCollector(t)
 
 	buf.Reset()
 	c.logDuration(context.Background(), "foo", timex.Now(), nil, "bar")

+ 2 - 11
core/stores/mon/util_test.go

@@ -3,12 +3,11 @@ package mon
 import (
 	"context"
 	"errors"
-	"strings"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 func TestFormatAddrs(t *testing.T) {
@@ -40,15 +39,7 @@ func TestFormatAddrs(t *testing.T) {
 }
 
 func Test_logDuration(t *testing.T) {
-	var buf strings.Builder
-	w := logx.NewWriter(&buf)
-	o := logx.Reset()
-	logx.SetWriter(w)
-
-	defer func() {
-		logx.Reset()
-		logx.SetWriter(o)
-	}()
+	buf := logtest.NewCollector(t)
 
 	buf.Reset()
 	logDuration(context.Background(), "foo", "bar", time.Millisecond, nil)

+ 6 - 23
core/stores/redis/hook_test.go

@@ -9,7 +9,7 @@ import (
 
 	red "github.com/go-redis/redis/v8"
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 	ztrace "github.com/zeromicro/go-zero/core/trace"
 	tracesdk "go.opentelemetry.io/otel/trace"
 )
@@ -47,8 +47,7 @@ func TestHookProcessCase2(t *testing.T) {
 	})
 	defer ztrace.StopAgent()
 
-	w, restore := injectLog()
-	defer restore()
+	w := logtest.NewCollector(t)
 
 	ctx, err := durationHook.BeforeProcess(context.Background(), red.NewCmd(context.Background()))
 	if err != nil {
@@ -115,8 +114,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
 	})
 	defer ztrace.StopAgent()
 
-	w, restore := injectLog()
-	defer restore()
+	w := logtest.NewCollector(t)
 
 	ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{
 		red.NewCmd(context.Background()),
@@ -135,8 +133,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
 }
 
 func TestHookProcessPipelineCase3(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
+	w := logtest.NewCollector(t)
 
 	assert.Nil(t, durationHook.AfterProcessPipeline(context.Background(), []red.Cmder{
 		red.NewCmd(context.Background()),
@@ -145,8 +142,7 @@ func TestHookProcessPipelineCase3(t *testing.T) {
 }
 
 func TestHookProcessPipelineCase4(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
+	w := logtest.NewCollector(t)
 
 	ctx := context.WithValue(context.Background(), startTimeKey, "foo")
 	assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
@@ -169,8 +165,7 @@ func TestHookProcessPipelineCase5(t *testing.T) {
 }
 
 func TestLogDuration(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
+	w := logtest.NewCollector(t)
 
 	logDuration(context.Background(), []red.Cmder{
 		red.NewCmd(context.Background(), "get", "foo"),
@@ -183,15 +178,3 @@ func TestLogDuration(t *testing.T) {
 	}, 1*time.Second)
 	assert.True(t, strings.Contains(w.String(), `get foo\nset bar 0`))
 }
-
-func injectLog() (r *strings.Builder, restore func()) {
-	var buf strings.Builder
-	w := logx.NewWriter(&buf)
-	o := logx.Reset()
-	logx.SetWriter(w)
-
-	return &buf, func() {
-		logx.Reset()
-		logx.SetWriter(o)
-	}
-}

+ 1 - 0
core/trace/tracetest/tracetest.go

@@ -16,5 +16,6 @@ func NewInMemoryExporter(t *testing.T) *tracetest.InMemoryExporter {
 		me.Reset()
 	})
 	otel.SetTracerProvider(trace.NewTracerProvider(trace.WithSyncer(me)))
+
 	return me
 }

+ 0 - 2
rest/handler/breakerhandler_test.go

@@ -7,12 +7,10 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stat"
 )
 
 func init() {
-	logx.Disable()
 	stat.SetReporter(nil)
 }
 

+ 0 - 4
rest/handler/contentsecurityhandler_test.go

@@ -62,10 +62,6 @@ type requestSettings struct {
 	signature   string
 }
 
-func init() {
-	log.SetOutput(io.Discard)
-}
-
 func TestContentSecurityHandler(t *testing.T) {
 	tests := []struct {
 		method      string

+ 0 - 5
rest/handler/cryptionhandler_test.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"encoding/base64"
 	"io"
-	"log"
 	"math/rand"
 	"net/http"
 	"net/http/httptest"
@@ -21,10 +20,6 @@ const (
 
 var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
 
-func init() {
-	log.SetOutput(io.Discard)
-}
-
 func TestCryptionHandlerGet(t *testing.T) {
 	req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
 	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

+ 0 - 5
rest/handler/loghandler_test.go

@@ -3,7 +3,6 @@ package handler
 import (
 	"bytes"
 	"io"
-	"log"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -14,10 +13,6 @@ import (
 	"github.com/zeromicro/go-zero/rest/internal/response"
 )
 
-func init() {
-	log.SetOutput(io.Discard)
-}
-
 func TestLogHandler(t *testing.T) {
 	handlers := []func(handler http.Handler) http.Handler{
 		LogHandler,

+ 0 - 6
rest/handler/maxconnshandler_test.go

@@ -1,8 +1,6 @@
 package handler
 
 import (
-	"io"
-	"log"
 	"net/http"
 	"net/http/httptest"
 	"sync"
@@ -14,10 +12,6 @@ import (
 
 const conns = 4
 
-func init() {
-	log.SetOutput(io.Discard)
-}
-
 func TestMaxConnsHandler(t *testing.T) {
 	var waitGroup sync.WaitGroup
 	waitGroup.Add(conns)

+ 0 - 6
rest/handler/recoverhandler_test.go

@@ -1,8 +1,6 @@
 package handler
 
 import (
-	"io"
-	"log"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -10,10 +8,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-func init() {
-	log.SetOutput(io.Discard)
-}
-
 func TestWithPanic(t *testing.T) {
 	handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		panic("whatever")

+ 0 - 6
rest/handler/sheddinghandler_test.go

@@ -1,8 +1,6 @@
 package handler
 
 import (
-	"io"
-	"log"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -12,10 +10,6 @@ import (
 	"github.com/zeromicro/go-zero/core/stat"
 )
 
-func init() {
-	log.SetOutput(io.Discard)
-}
-
 func TestSheddingHandlerAccept(t *testing.T) {
 	metrics := stat.NewMetrics("unit-test")
 	shedder := mockShedder{

+ 8 - 6
rest/handler/timeouthandler.go

@@ -31,14 +31,14 @@ const (
 // Notice: even if canceled in server side, 499 will be logged as well.
 func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
 	return func(next http.Handler) http.Handler {
-		if duration > 0 {
-			return &timeoutHandler{
-				handler: next,
-				dt:      duration,
-			}
+		if duration <= 0 {
+			return next
 		}
 
-		return next
+		return &timeoutHandler{
+			handler: next,
+			dt:      duration,
+		}
 	}
 }
 
@@ -207,9 +207,11 @@ func relevantCaller() runtime.Frame {
 		if !strings.HasPrefix(frame.Function, "net/http.") {
 			return frame
 		}
+
 		if !more {
 			break
 		}
 	}
+
 	return frame
 }

+ 43 - 14
rest/handler/timeouthandler_test.go

@@ -2,21 +2,16 @@ package handler
 
 import (
 	"context"
-	"io"
-	"log"
 	"net/http"
 	"net/http/httptest"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 	"github.com/zeromicro/go-zero/rest/internal/response"
 )
 
-func init() {
-	log.SetOutput(io.Discard)
-}
-
 func TestTimeout(t *testing.T) {
 	timeoutHandler := TimeoutHandler(time.Millisecond)
 	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -45,7 +40,12 @@ func TestWithTimeoutTimedout(t *testing.T) {
 	timeoutHandler := TimeoutHandler(time.Millisecond)
 	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		time.Sleep(time.Millisecond * 10)
-		w.Write([]byte(`foo`))
+		_, err := w.Write([]byte(`foo`))
+		if err != nil {
+			w.WriteHeader(http.StatusInternalServerError)
+			return
+		}
+
 		w.WriteHeader(http.StatusOK)
 	}))
 
@@ -96,7 +96,12 @@ func TestTimeoutWebsocket(t *testing.T) {
 func TestTimeoutWroteHeaderTwice(t *testing.T) {
 	timeoutHandler := TimeoutHandler(time.Minute)
 	handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Write([]byte(`hello`))
+		_, err := w.Write([]byte(`hello`))
+		if err != nil {
+			w.WriteHeader(http.StatusInternalServerError)
+			return
+		}
+
 		w.Header().Set("foo", "bar")
 		w.WriteHeader(http.StatusOK)
 	}))
@@ -145,7 +150,7 @@ func TestTimeoutHijack(t *testing.T) {
 	}
 
 	assert.NotPanics(t, func() {
-		writer.Hijack()
+		_, _, _ = writer.Hijack()
 	})
 
 	writer = &timeoutWriter{
@@ -155,7 +160,7 @@ func TestTimeoutHijack(t *testing.T) {
 	}
 
 	assert.NotPanics(t, func() {
-		writer.Hijack()
+		_, _, _ = writer.Hijack()
 	})
 }
 
@@ -165,7 +170,7 @@ func TestTimeoutPusher(t *testing.T) {
 	}
 
 	assert.Panics(t, func() {
-		handler.Push("any", nil)
+		_ = handler.Push("any", nil)
 	})
 
 	handler = &timeoutWriter{
@@ -174,20 +179,44 @@ func TestTimeoutPusher(t *testing.T) {
 	assert.Equal(t, http.ErrNotSupported, handler.Push("any", nil))
 }
 
+func TestTimeoutWriter_Hijack(t *testing.T) {
+	writer := &timeoutWriter{
+		w:   httptest.NewRecorder(),
+		h:   make(http.Header),
+		req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
+	}
+	_, _, err := writer.Hijack()
+	assert.Error(t, err)
+}
+
+func TestTimeoutWroteTwice(t *testing.T) {
+	c := logtest.NewCollector(t)
+	writer := &timeoutWriter{
+		w: &response.WithCodeResponseWriter{
+			Writer: httptest.NewRecorder(),
+		},
+		h:   make(http.Header),
+		req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
+	}
+	writer.writeHeaderLocked(http.StatusOK)
+	writer.writeHeaderLocked(http.StatusOK)
+	assert.Contains(t, c.String(), "superfluous response.WriteHeader call")
+}
+
 type mockedPusher struct{}
 
 func (m mockedPusher) Header() http.Header {
 	panic("implement me")
 }
 
-func (m mockedPusher) Write(bytes []byte) (int, error) {
+func (m mockedPusher) Write(_ []byte) (int, error) {
 	panic("implement me")
 }
 
-func (m mockedPusher) WriteHeader(statusCode int) {
+func (m mockedPusher) WriteHeader(_ int) {
 	panic("implement me")
 }
 
-func (m mockedPusher) Push(target string, opts *http.PushOptions) error {
+func (m mockedPusher) Push(_ string, _ *http.PushOptions) error {
 	panic("implement me")
 }

+ 3 - 12
rest/internal/log_test.go

@@ -8,7 +8,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 func TestInfo(t *testing.T) {
@@ -25,20 +25,11 @@ func TestInfo(t *testing.T) {
 }
 
 func TestError(t *testing.T) {
-	var buf strings.Builder
-	w := logx.NewWriter(&buf)
-	o := logx.Reset()
-	logx.SetWriter(w)
-
-	defer func() {
-		logx.Reset()
-		logx.SetWriter(o)
-	}()
-
+	c := logtest.NewCollector(t)
 	req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
 	Error(req, "first")
 	Errorf(req, "second %s", "third")
-	val := buf.String()
+	val := c.String()
 	assert.True(t, strings.Contains(val, "first"))
 	assert.True(t, strings.Contains(val, "second"))
 	assert.True(t, strings.Contains(val, "third"))

+ 2 - 4
rest/server_test.go

@@ -14,7 +14,7 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/conf"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 	"github.com/zeromicro/go-zero/rest/chain"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/internal/cors"
@@ -22,9 +22,7 @@ import (
 )
 
 func TestNewServer(t *testing.T) {
-	writer := logx.Reset()
-	defer logx.SetWriter(writer)
-	logx.SetWriter(logx.NewWriter(io.Discard))
+	logtest.Discard(t)
 
 	const configYaml = `
 Name: foo

+ 25 - 62
zrpc/internal/rpclogger_test.go

@@ -1,121 +1,96 @@
 package internal
 
 import (
-	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 
 const content = "foo"
 
 func TestLoggerError(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Error(content)
-	assert.Contains(t, w.String(), content)
+	assert.Contains(t, c.String(), content)
 }
 
 func TestLoggerErrorf(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Errorf(content)
-	assert.Contains(t, w.String(), content)
+	assert.Contains(t, c.String(), content)
 }
 
 func TestLoggerErrorln(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Errorln(content)
-	assert.Contains(t, w.String(), content)
+	assert.Contains(t, c.String(), content)
 }
 
 func TestLoggerFatal(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Fatal(content)
-	assert.Contains(t, w.String(), content)
+	assert.Contains(t, c.String(), content)
 }
 
 func TestLoggerFatalf(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Fatalf(content)
-	assert.Contains(t, w.String(), content)
+	assert.Contains(t, c.String(), content)
 }
 
 func TestLoggerFatalln(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Fatalln(content)
-	assert.Contains(t, w.String(), content)
+	assert.Contains(t, c.String(), content)
 }
 
 func TestLoggerInfo(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Info(content)
-	assert.Empty(t, w.String())
+	assert.Empty(t, c.String())
 }
 
 func TestLoggerInfof(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Infof(content)
-	assert.Empty(t, w.String())
+	assert.Empty(t, c.String())
 }
 
 func TestLoggerWarning(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Warning(content)
-	assert.Empty(t, w.String())
+	assert.Empty(t, c.String())
 }
 
 func TestLoggerInfoln(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Infoln(content)
-	assert.Empty(t, w.String())
+	assert.Empty(t, c.String())
 }
 
 func TestLoggerWarningf(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Warningf(content)
-	assert.Empty(t, w.String())
+	assert.Empty(t, c.String())
 }
 
 func TestLoggerWarningln(t *testing.T) {
-	w, restore := injectLog()
-	defer restore()
-
+	c := logtest.NewCollector(t)
 	logger := new(Logger)
 	logger.Warningln(content)
-	assert.Empty(t, w.String())
+	assert.Empty(t, c.String())
 }
 
 func TestLogger_V(t *testing.T) {
@@ -125,15 +100,3 @@ func TestLogger_V(t *testing.T) {
 	// grpclog.infoLog
 	assert.False(t, logger.V(0))
 }
-
-func injectLog() (r *strings.Builder, restore func()) {
-	var buf strings.Builder
-	w := logx.NewWriter(&buf)
-	o := logx.Reset()
-	logx.SetWriter(w)
-
-	return &buf, func() {
-		logx.Reset()
-		logx.SetWriter(o)
-	}
-}

+ 3 - 0
zrpc/internal/rpcserver_test.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/proc"
@@ -40,6 +41,8 @@ func TestRpcServer(t *testing.T) {
 	}()
 
 	wg.Wait()
+	time.Sleep(100 * time.Millisecond)
+
 	lock.Lock()
 	grpcServer.GracefulStop()
 	lock.Unlock()