Browse Source

feat: log 404 requests with traceid (#1554)

Kevin Wan 3 years ago
parent
commit
842656aa90

+ 22 - 0
rest/engine.go

@@ -14,6 +14,7 @@ import (
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/handler"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/internal"
 	"github.com/zeromicro/go-zero/rest/internal"
+	"github.com/zeromicro/go-zero/rest/internal/response"
 )
 )
 
 
 // use 1000m to represent 100%
 // use 1000m to represent 100%
@@ -154,6 +155,27 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
 	return ng.shedder
 	return ng.shedder
 }
 }
 
 
+// notFoundHandler returns a middleware that handles 404 not found requests.
+func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		chain := alice.New(
+			handler.TracingHandler(ng.conf.Name, ""),
+			ng.getLogHandler(),
+		)
+
+		var h http.Handler
+		if next != nil {
+			h = chain.Then(next)
+		} else {
+			h = chain.Then(http.NotFoundHandler())
+		}
+
+		cw := response.NewHeaderOnceResponseWriter(w)
+		h.ServeHTTP(cw, r)
+		cw.WriteHeader(http.StatusNotFound)
+	})
+}
+
 func (ng *engine) setTlsConfig(cfg *tls.Config) {
 func (ng *engine) setTlsConfig(cfg *tls.Config) {
 	ng.tlsConfig = cfg
 	ng.tlsConfig = cfg
 }
 }

+ 73 - 0
rest/engine_test.go

@@ -1,13 +1,17 @@
 package rest
 package rest
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"net/http"
 	"net/http"
+	"net/http/httptest"
+	"sync/atomic"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/conf"
 	"github.com/zeromicro/go-zero/core/conf"
+	"github.com/zeromicro/go-zero/core/logx"
 )
 )
 
 
 func TestNewEngine(t *testing.T) {
 func TestNewEngine(t *testing.T) {
@@ -190,6 +194,75 @@ func TestEngine_checkedTimeout(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestEngine_notFoundHandler(t *testing.T) {
+	logx.Disable()
+
+	ng := newEngine(RestConf{})
+	ts := httptest.NewServer(ng.notFoundHandler(nil))
+	defer ts.Close()
+
+	client := ts.Client()
+	err := func(ctx context.Context) error {
+		req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
+		assert.Nil(t, err)
+		res, err := client.Do(req)
+		assert.Nil(t, err)
+		assert.Equal(t, http.StatusNotFound, res.StatusCode)
+		return res.Body.Close()
+	}(context.Background())
+
+	assert.Nil(t, err)
+}
+
+func TestEngine_notFoundHandlerNotNil(t *testing.T) {
+	logx.Disable()
+
+	ng := newEngine(RestConf{})
+	var called int32
+	ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt32(&called, 1)
+	})))
+	defer ts.Close()
+
+	client := ts.Client()
+	err := func(ctx context.Context) error {
+		req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
+		assert.Nil(t, err)
+		res, err := client.Do(req)
+		assert.Nil(t, err)
+		assert.Equal(t, http.StatusNotFound, res.StatusCode)
+		return res.Body.Close()
+	}(context.Background())
+
+	assert.Nil(t, err)
+	assert.Equal(t, int32(1), atomic.LoadInt32(&called))
+}
+
+func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
+	logx.Disable()
+
+	ng := newEngine(RestConf{})
+	var called int32
+	ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt32(&called, 1)
+		w.WriteHeader(http.StatusExpectationFailed)
+	})))
+	defer ts.Close()
+
+	client := ts.Client()
+	err := func(ctx context.Context) error {
+		req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
+		assert.Nil(t, err)
+		res, err := client.Do(req)
+		assert.Nil(t, err)
+		assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
+		return res.Body.Close()
+	}(context.Background())
+
+	assert.Nil(t, err)
+	assert.Equal(t, int32(1), atomic.LoadInt32(&called))
+}
+
 type mockedRouter struct{}
 type mockedRouter struct{}
 
 
 func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {

+ 2 - 47
rest/handler/authhandler.go

@@ -1,15 +1,14 @@
 package handler
 package handler
 
 
 import (
 import (
-	"bufio"
 	"context"
 	"context"
 	"errors"
 	"errors"
-	"net"
 	"net/http"
 	"net/http"
 	"net/http/httputil"
 	"net/http/httputil"
 
 
 	"github.com/golang-jwt/jwt/v4"
 	"github.com/golang-jwt/jwt/v4"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/rest/internal/response"
 	"github.com/zeromicro/go-zero/rest/token"
 	"github.com/zeromicro/go-zero/rest/token"
 )
 )
 
 
@@ -105,7 +104,7 @@ func detailAuthLog(r *http.Request, reason string) {
 }
 }
 
 
 func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
 func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
-	writer := newGuardedResponseWriter(w)
+	writer := response.NewHeaderOnceResponseWriter(w)
 
 
 	if err != nil {
 	if err != nil {
 		detailAuthLog(r, err.Error())
 		detailAuthLog(r, err.Error())
@@ -121,47 +120,3 @@ func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback Un
 	// if user not setting HTTP header, we set header with 401
 	// if user not setting HTTP header, we set header with 401
 	writer.WriteHeader(http.StatusUnauthorized)
 	writer.WriteHeader(http.StatusUnauthorized)
 }
 }
-
-type guardedResponseWriter struct {
-	writer      http.ResponseWriter
-	wroteHeader bool
-}
-
-func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
-	return &guardedResponseWriter{
-		writer: w,
-	}
-}
-
-func (grw *guardedResponseWriter) Flush() {
-	if flusher, ok := grw.writer.(http.Flusher); ok {
-		flusher.Flush()
-	}
-}
-
-func (grw *guardedResponseWriter) Header() http.Header {
-	return grw.writer.Header()
-}
-
-// Hijack implements the http.Hijacker interface.
-// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
-func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	if hijacked, ok := grw.writer.(http.Hijacker); ok {
-		return hijacked.Hijack()
-	}
-
-	return nil, nil, errors.New("server doesn't support hijacking")
-}
-
-func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
-	return grw.writer.Write(body)
-}
-
-func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
-	if grw.wroteHeader {
-		return
-	}
-
-	grw.wroteHeader = true
-	grw.writer.WriteHeader(statusCode)
-}

+ 0 - 20
rest/handler/authhandler_test.go

@@ -90,26 +90,6 @@ func TestAuthHandler_NilError(t *testing.T) {
 	})
 	})
 }
 }
 
 
-func TestAuthHandler_Flush(t *testing.T) {
-	resp := httptest.NewRecorder()
-	handler := newGuardedResponseWriter(resp)
-	handler.Flush()
-	assert.True(t, resp.Flushed)
-}
-
-func TestAuthHandler_Hijack(t *testing.T) {
-	resp := httptest.NewRecorder()
-	writer := newGuardedResponseWriter(resp)
-	assert.NotPanics(t, func() {
-		writer.Hijack()
-	})
-
-	writer = newGuardedResponseWriter(mockedHijackable{resp})
-	assert.NotPanics(t, func() {
-		writer.Hijack()
-	})
-}
-
 func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
 func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
 	now := time.Now().Unix()
 	now := time.Now().Unix()
 	claims := make(jwt.MapClaims)
 	claims := make(jwt.MapClaims)

+ 2 - 2
rest/handler/breakerhandler.go

@@ -9,7 +9,7 @@ import (
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stat"
 	"github.com/zeromicro/go-zero/core/stat"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/httpx"
-	"github.com/zeromicro/go-zero/rest/internal/security"
+	"github.com/zeromicro/go-zero/rest/internal/response"
 )
 )
 
 
 const breakerSeparator = "://"
 const breakerSeparator = "://"
@@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
 				return
 				return
 			}
 			}
 
 
-			cw := &security.WithCodeResponseWriter{Writer: w}
+			cw := &response.WithCodeResponseWriter{Writer: w}
 			defer func() {
 			defer func() {
 				if cw.Code < http.StatusInternalServerError {
 				if cw.Code < http.StatusInternalServerError {
 					promise.Accept()
 					promise.Accept()

+ 2 - 2
rest/handler/prometheushandler.go

@@ -8,7 +8,7 @@ import (
 	"github.com/zeromicro/go-zero/core/metric"
 	"github.com/zeromicro/go-zero/core/metric"
 	"github.com/zeromicro/go-zero/core/prometheus"
 	"github.com/zeromicro/go-zero/core/prometheus"
 	"github.com/zeromicro/go-zero/core/timex"
 	"github.com/zeromicro/go-zero/core/timex"
-	"github.com/zeromicro/go-zero/rest/internal/security"
+	"github.com/zeromicro/go-zero/rest/internal/response"
 )
 )
 
 
 const serverNamespace = "http_server"
 const serverNamespace = "http_server"
@@ -41,7 +41,7 @@ func PrometheusHandler(path string) func(http.Handler) http.Handler {
 
 
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			startTime := timex.Now()
 			startTime := timex.Now()
-			cw := &security.WithCodeResponseWriter{Writer: w}
+			cw := &response.WithCodeResponseWriter{Writer: w}
 			defer func() {
 			defer func() {
 				metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
 				metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
 				metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
 				metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))

+ 2 - 2
rest/handler/sheddinghandler.go

@@ -8,7 +8,7 @@ import (
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stat"
 	"github.com/zeromicro/go-zero/core/stat"
 	"github.com/zeromicro/go-zero/rest/httpx"
 	"github.com/zeromicro/go-zero/rest/httpx"
-	"github.com/zeromicro/go-zero/rest/internal/security"
+	"github.com/zeromicro/go-zero/rest/internal/response"
 )
 )
 
 
 const serviceType = "api"
 const serviceType = "api"
@@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
 				return
 				return
 			}
 			}
 
 
-			cw := &security.WithCodeResponseWriter{Writer: w}
+			cw := &response.WithCodeResponseWriter{Writer: w}
 			defer func() {
 			defer func() {
 				if cw.Code == http.StatusServiceUnavailable {
 				if cw.Code == http.StatusServiceUnavailable {
 					promise.Fail()
 					promise.Fail()

+ 6 - 2
rest/handler/tracinghandler.go

@@ -18,12 +18,16 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler {
 
 
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
 			ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
+			spanName := path
+			if len(spanName) == 0 {
+				spanName = r.URL.Path
+			}
 			spanCtx, span := tracer.Start(
 			spanCtx, span := tracer.Start(
 				ctx,
 				ctx,
-				path,
+				spanName,
 				oteltrace.WithSpanKind(oteltrace.SpanKindServer),
 				oteltrace.WithSpanKind(oteltrace.SpanKindServer),
 				oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
 				oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
-					serviceName, path, r)...),
+					serviceName, spanName, r)...),
 			)
 			)
 			defer span.End()
 			defer span.End()
 
 

+ 28 - 24
rest/handler/tracinghandler_test.go

@@ -6,6 +6,7 @@ import (
 	"net/http/httptest"
 	"net/http/httptest"
 	"testing"
 	"testing"
 
 
+	"github.com/justinas/alice"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	ztrace "github.com/zeromicro/go-zero/core/trace"
 	ztrace "github.com/zeromicro/go-zero/core/trace"
 	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel"
@@ -21,28 +22,31 @@ func TestOtelHandler(t *testing.T) {
 		Sampler:  1.0,
 		Sampler:  1.0,
 	})
 	})
 
 
-	ts := httptest.NewServer(
-		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
-			spanCtx := trace.SpanContextFromContext(ctx)
-			assert.Equal(t, true, spanCtx.IsValid())
-		}),
-	)
-	defer ts.Close()
-
-	client := ts.Client()
-	err := func(ctx context.Context) error {
-		ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test")
-		defer span.End()
-
-		req, _ := http.NewRequest("GET", ts.URL, nil)
-		otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
-
-		res, err := client.Do(req)
-		assert.Equal(t, err, nil)
-		_ = res.Body.Close()
-		return nil
-	}(context.Background())
-
-	assert.Equal(t, err, nil)
+	for _, test := range []string{"", "bar"} {
+		t.Run(test, func(t *testing.T) {
+			h := alice.New(TracingHandler("foo", test)).Then(
+				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
+					spanCtx := trace.SpanContextFromContext(ctx)
+					assert.True(t, spanCtx.IsValid())
+				}))
+			ts := httptest.NewServer(h)
+			defer ts.Close()
+
+			client := ts.Client()
+			err := func(ctx context.Context) error {
+				ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test")
+				defer span.End()
+
+				req, _ := http.NewRequest("GET", ts.URL, nil)
+				otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
+
+				res, err := client.Do(req)
+				assert.Nil(t, err)
+				return res.Body.Close()
+			}(context.Background())
+
+			assert.Nil(t, err)
+		})
+	}
 }
 }

+ 3 - 42
rest/internal/cors/handlers.go

@@ -1,10 +1,9 @@
 package cors
 package cors
 
 
 import (
 import (
-	"bufio"
-	"errors"
-	"net"
 	"net/http"
 	"net/http"
+
+	"github.com/zeromicro/go-zero/rest/internal/response"
 )
 )
 
 
 const (
 const (
@@ -30,7 +29,7 @@ const (
 // At most one origin can be specified, other origins are ignored if given, default to be *.
 // At most one origin can be specified, other origins are ignored if given, default to be *.
 func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
 func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		gw := &guardedResponseWriter{w: w}
+		gw := response.NewHeaderOnceResponseWriter(w)
 		checkAndSetHeaders(gw, r, origins)
 		checkAndSetHeaders(gw, r, origins)
 		if fn != nil {
 		if fn != nil {
 			fn(gw)
 			fn(gw)
@@ -62,44 +61,6 @@ func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc
 	}
 	}
 }
 }
 
 
-type guardedResponseWriter struct {
-	w           http.ResponseWriter
-	wroteHeader bool
-}
-
-func (w *guardedResponseWriter) Flush() {
-	if flusher, ok := w.w.(http.Flusher); ok {
-		flusher.Flush()
-	}
-}
-
-func (w *guardedResponseWriter) Header() http.Header {
-	return w.w.Header()
-}
-
-// Hijack implements the http.Hijacker interface.
-// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
-func (w *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	if hijacked, ok := w.w.(http.Hijacker); ok {
-		return hijacked.Hijack()
-	}
-
-	return nil, nil, errors.New("server doesn't support hijacking")
-}
-
-func (w *guardedResponseWriter) Write(bytes []byte) (int, error) {
-	return w.w.Write(bytes)
-}
-
-func (w *guardedResponseWriter) WriteHeader(code int) {
-	if w.wroteHeader {
-		return
-	}
-
-	w.w.WriteHeader(code)
-	w.wroteHeader = true
-}
-
 func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) {
 func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) {
 	setVaryHeaders(w, r)
 	setVaryHeaders(w, r)
 
 

+ 0 - 47
rest/internal/cors/handlers_test.go

@@ -1,8 +1,6 @@
 package cors
 package cors
 
 
 import (
 import (
-	"bufio"
-	"net"
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
 	"testing"
 	"testing"
@@ -131,48 +129,3 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
-
-func TestGuardedResponseWriter_Flush(t *testing.T) {
-	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
-	handler := NotAllowedHandler(func(w http.ResponseWriter) {
-		w.Header().Set("X-Test", "test")
-		w.WriteHeader(http.StatusServiceUnavailable)
-		_, err := w.Write([]byte("content"))
-		assert.Nil(t, err)
-
-		flusher, ok := w.(http.Flusher)
-		assert.True(t, ok)
-		flusher.Flush()
-	}, "foo.com")
-
-	resp := httptest.NewRecorder()
-	handler.ServeHTTP(resp, req)
-	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
-	assert.Equal(t, "test", resp.Header().Get("X-Test"))
-	assert.Equal(t, "content", resp.Body.String())
-}
-
-func TestGuardedResponseWriter_Hijack(t *testing.T) {
-	resp := httptest.NewRecorder()
-	writer := &guardedResponseWriter{
-		w: resp,
-	}
-	assert.NotPanics(t, func() {
-		writer.Hijack()
-	})
-
-	writer = &guardedResponseWriter{
-		w: mockedHijackable{resp},
-	}
-	assert.NotPanics(t, func() {
-		writer.Hijack()
-	})
-}
-
-type mockedHijackable struct {
-	*httptest.ResponseRecorder
-}
-
-func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	return nil, nil, nil
-}

+ 57 - 0
rest/internal/response/headeronceresponsewriter.go

@@ -0,0 +1,57 @@
+package response
+
+import (
+	"bufio"
+	"errors"
+	"net"
+	"net/http"
+)
+
+// HeaderOnceResponseWriter is a http.ResponseWriter implementation
+// that only the first WriterHeader takes effect.
+type HeaderOnceResponseWriter struct {
+	w           http.ResponseWriter
+	wroteHeader bool
+}
+
+// NewHeaderOnceResponseWriter returns a HeaderOnceResponseWriter.
+func NewHeaderOnceResponseWriter(w http.ResponseWriter) http.ResponseWriter {
+	return &HeaderOnceResponseWriter{w: w}
+}
+
+// Flush flushes the response writer.
+func (w *HeaderOnceResponseWriter) Flush() {
+	if flusher, ok := w.w.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
+// Header returns the http header.
+func (w *HeaderOnceResponseWriter) Header() http.Header {
+	return w.w.Header()
+}
+
+// Hijack implements the http.Hijacker interface.
+// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
+func (w *HeaderOnceResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	if hijacked, ok := w.w.(http.Hijacker); ok {
+		return hijacked.Hijack()
+	}
+
+	return nil, nil, errors.New("server doesn't support hijacking")
+}
+
+// Write writes bytes into w.
+func (w *HeaderOnceResponseWriter) Write(bytes []byte) (int, error) {
+	return w.w.Write(bytes)
+}
+
+// WriteHeader writes code into w, and not sealing the writer.
+func (w *HeaderOnceResponseWriter) WriteHeader(code int) {
+	if w.wroteHeader {
+		return
+	}
+
+	w.w.WriteHeader(code)
+	w.wroteHeader = true
+}

+ 58 - 0
rest/internal/response/headeronceresponsewriter_test.go

@@ -0,0 +1,58 @@
+package response
+
+import (
+	"bufio"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestHeaderOnceResponseWriter_Flush(t *testing.T) {
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		cw := NewHeaderOnceResponseWriter(w)
+		cw.Header().Set("X-Test", "test")
+		cw.WriteHeader(http.StatusServiceUnavailable)
+		cw.WriteHeader(http.StatusExpectationFailed)
+		_, err := cw.Write([]byte("content"))
+		assert.Nil(t, err)
+
+		flusher, ok := cw.(http.Flusher)
+		assert.True(t, ok)
+		flusher.Flush()
+	})
+
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
+	assert.Equal(t, "test", resp.Header().Get("X-Test"))
+	assert.Equal(t, "content", resp.Body.String())
+}
+
+func TestHeaderOnceResponseWriter_Hijack(t *testing.T) {
+	resp := httptest.NewRecorder()
+	writer := &HeaderOnceResponseWriter{
+		w: resp,
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+
+	writer = &HeaderOnceResponseWriter{
+		w: mockedHijackable{resp},
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+}
+
+type mockedHijackable struct {
+	*httptest.ResponseRecorder
+}
+
+func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	return nil, nil, nil
+}

+ 7 - 2
rest/internal/security/withcoderesponsewriter.go → rest/internal/response/withcoderesponsewriter.go

@@ -1,7 +1,8 @@
-package security
+package response
 
 
 import (
 import (
 	"bufio"
 	"bufio"
+	"errors"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 )
 )
@@ -27,7 +28,11 @@ func (w *WithCodeResponseWriter) Header() http.Header {
 // Hijack implements the http.Hijacker interface.
 // Hijack implements the http.Hijacker interface.
 // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
 // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
 func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
 func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	return w.Writer.(http.Hijacker).Hijack()
+	if hijacked, ok := w.Writer.(http.Hijacker); ok {
+		return hijacked.Hijack()
+	}
+
+	return nil, nil, errors.New("server doesn't support hijacking")
 }
 }
 
 
 // Write writes bytes into w.
 // Write writes bytes into w.

+ 18 - 1
rest/internal/security/withcoderesponsewriter_test.go → rest/internal/response/withcoderesponsewriter_test.go

@@ -1,4 +1,4 @@
-package security
+package response
 
 
 import (
 import (
 	"net/http"
 	"net/http"
@@ -31,3 +31,20 @@ func TestWithCodeResponseWriter(t *testing.T) {
 	assert.Equal(t, "test", resp.Header().Get("X-Test"))
 	assert.Equal(t, "test", resp.Header().Get("X-Test"))
 	assert.Equal(t, "content", resp.Body.String())
 	assert.Equal(t, "content", resp.Body.String())
 }
 }
+
+func TestWithCodeResponseWriter_Hijack(t *testing.T) {
+	resp := httptest.NewRecorder()
+	writer := &WithCodeResponseWriter{
+		Writer: resp,
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+
+	writer = &WithCodeResponseWriter{
+		Writer: mockedHijackable{resp},
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+}

+ 3 - 1
rest/server.go

@@ -49,6 +49,7 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
 		router: router.NewRouter(),
 		router: router.NewRouter(),
 	}
 	}
 
 
+	opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...)
 	for _, opt := range opts {
 	for _, opt := range opts {
 		opt(server)
 		opt(server)
 	}
 	}
@@ -163,7 +164,8 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
 // WithNotFoundHandler returns a RunOption with not found handler set to given handler.
 // WithNotFoundHandler returns a RunOption with not found handler set to given handler.
 func WithNotFoundHandler(handler http.Handler) RunOption {
 func WithNotFoundHandler(handler http.Handler) RunOption {
 	return func(server *Server) {
 	return func(server *Server) {
-		server.router.SetNotFoundHandler(handler)
+		notFoundHandler := server.ngin.notFoundHandler(handler)
+		server.router.SetNotFoundHandler(notFoundHandler)
 	}
 	}
 }
 }