Преглед изворни кода

feature : responses whit context (#2637)

heyehang пре 2 година
родитељ
комит
a644ec7edd

+ 3 - 3
gateway/server.go

@@ -122,7 +122,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
 	return func(w http.ResponseWriter, r *http.Request) {
 		parser, err := internal.NewRequestParser(r, resolver)
 		if err != nil {
-			httpx.Error(w, err)
+			httpx.ErrorCtx(r.Context(), w, err)
 			return
 		}
 
@@ -134,12 +134,12 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
 		handler := internal.NewEventHandler(w, resolver)
 		if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header),
 			handler, parser.Next); err != nil {
-			httpx.Error(w, err)
+			httpx.ErrorCtx(r.Context(), w, err)
 		}
 
 		st := handler.Status
 		if st.Code() != codes.OK {
-			httpx.Error(w, st.Err())
+			httpx.ErrorCtx(r.Context(), w, st.Err())
 		}
 	}
 }

+ 1 - 1
rest/handler/timeouthandler.go

@@ -99,7 +99,7 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		defer tw.mu.Unlock()
 		// there isn't any user-defined middleware before TimoutHandler,
 		// so we can guarantee that cancelation in biz related code won't come here.
-		httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
+		httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
 			if errors.Is(err, context.Canceled) {
 				w.WriteHeader(statusClientClosedRequest)
 			} else {

+ 72 - 2
rest/httpx/responses.go

@@ -1,6 +1,7 @@
 package httpx
 
 import (
+	"context"
 	"encoding/json"
 	"net/http"
 	"sync"
@@ -11,8 +12,9 @@ import (
 )
 
 var (
-	errorHandler func(error) (int, interface{})
-	lock         sync.RWMutex
+	errorHandler    func(error) (int, interface{})
+	lock            sync.RWMutex
+	errorHandlerCtx func(context.Context, error) (int, interface{})
 )
 
 // Error writes err into w.
@@ -87,3 +89,71 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) {
 		logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
 	}
 }
+
+// Error writes err into w.
+func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
+	lock.RLock()
+	handlerCtx := errorHandlerCtx
+	lock.RUnlock()
+
+	if handlerCtx == nil {
+		if len(fns) > 0 {
+			fns[0](w, err)
+		} else if errcode.IsGrpcError(err) {
+			// don't unwrap error and get status.Message(),
+			// it hides the rpc error headers.
+			http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
+		} else {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+		}
+
+		return
+	}
+
+	code, body := handlerCtx(ctx, err)
+	if body == nil {
+		w.WriteHeader(code)
+		return
+	}
+
+	e, ok := body.(error)
+	if ok {
+		http.Error(w, e.Error(), code)
+	} else {
+		WriteJsonCtx(ctx, w, code, body)
+	}
+}
+
+// OkJson writes v into w with 200 OK.
+func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) {
+	WriteJsonCtx(ctx, w, http.StatusOK, v)
+}
+
+// WriteJson writes v as json string into w with code.
+func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) {
+	bs, err := json.Marshal(v)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	w.Header().Set(ContentType, header.JsonContentType)
+	w.WriteHeader(code)
+
+	if n, err := w.Write(bs); err != nil {
+		// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
+		// so it's ignored here.
+		if err != http.ErrHandlerTimeout {
+			logx.WithContext(ctx).Errorf("write response failed, error: %s", err)
+		}
+	} else if n < len(bs) {
+		logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
+	}
+}
+
+// SetErrorHandler sets the error handler, which is called on calling Error.
+func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) {
+	lock.Lock()
+	defer lock.Unlock()
+	errorHandlerCtx = handlerCtx
+}

+ 113 - 0
rest/httpx/responses_test.go

@@ -1,6 +1,7 @@
 package httpx
 
 import (
+	"context"
 	"errors"
 	"net/http"
 	"strings"
@@ -214,3 +215,115 @@ func (w *tracedResponseWriter) WriteHeader(code int) {
 	w.wroteHeader = true
 	w.code = code
 }
+
+func TestErrorCtx(t *testing.T) {
+	const (
+		body        = "foo"
+		wrappedBody = `"foo"`
+	)
+
+	tests := []struct {
+		name            string
+		input           string
+		errorHandlerCtx func(context.Context, error) (int, interface{})
+		expectHasBody   bool
+		expectBody      string
+		expectCode      int
+	}{
+		{
+			name:          "default error handler",
+			input:         body,
+			expectHasBody: true,
+			expectBody:    body,
+			expectCode:    http.StatusBadRequest,
+		},
+		{
+			name:  "customized error handler return string",
+			input: body,
+			errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
+				return http.StatusForbidden, err.Error()
+			},
+			expectHasBody: true,
+			expectBody:    wrappedBody,
+			expectCode:    http.StatusForbidden,
+		},
+		{
+			name:  "customized error handler return error",
+			input: body,
+			errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
+				return http.StatusForbidden, err
+			},
+			expectHasBody: true,
+			expectBody:    body,
+			expectCode:    http.StatusForbidden,
+		},
+		{
+			name:  "customized error handler return nil",
+			input: body,
+			errorHandlerCtx: func(context.Context, error) (int, interface{}) {
+				return http.StatusForbidden, nil
+			},
+			expectHasBody: false,
+			expectBody:    "",
+			expectCode:    http.StatusForbidden,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			w := tracedResponseWriter{
+				headers: make(map[string][]string),
+			}
+			if test.errorHandlerCtx != nil {
+				lock.RLock()
+				prev := errorHandlerCtx
+				lock.RUnlock()
+				SetErrorHandlerCtx(test.errorHandlerCtx)
+				defer func() {
+					lock.Lock()
+					test.errorHandlerCtx = prev
+					lock.Unlock()
+				}()
+			}
+			ErrorCtx(context.Background(), &w, errors.New(test.input))
+			assert.Equal(t, test.expectCode, w.code)
+			assert.Equal(t, test.expectHasBody, w.hasBody)
+			assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
+		})
+	}
+
+	//The current handler is a global event,Set default values to avoid impacting subsequent unit tests
+	SetErrorHandlerCtx(nil)
+}
+
+func TestErrorWithGrpcErrorCtx(t *testing.T) {
+	w := tracedResponseWriter{
+		headers: make(map[string][]string),
+	}
+	ErrorCtx(context.Background(), &w, status.Error(codes.Unavailable, "foo"))
+	assert.Equal(t, http.StatusServiceUnavailable, w.code)
+	assert.True(t, w.hasBody)
+	assert.True(t, strings.Contains(w.builder.String(), "foo"))
+}
+
+func TestErrorWithHandlerCtx(t *testing.T) {
+	w := tracedResponseWriter{
+		headers: make(map[string][]string),
+	}
+	ErrorCtx(context.Background(), &w, errors.New("foo"), func(w http.ResponseWriter, err error) {
+		http.Error(w, err.Error(), 499)
+	})
+	assert.Equal(t, 499, w.code)
+	assert.True(t, w.hasBody)
+	assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
+}
+
+func TestWriteJsonCtxMarshalFailed(t *testing.T) {
+	w := tracedResponseWriter{
+		headers: make(map[string][]string),
+	}
+	WriteJsonCtx(context.Background(), &w, http.StatusOK, map[string]interface{}{
+		"Data": complex(0, 0),
+	})
+	assert.Equal(t, http.StatusInternalServerError, w.code)
+}

+ 3 - 3
tools/goctl/api/gogen/handler.tpl

@@ -11,16 +11,16 @@ func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		{{if .HasRequest}}var req types.{{.RequestType}}
 		if err := httpx.Parse(r, &req); err != nil {
-			httpx.Error(w, err)
+			httpx.ErrorCtx(r.Context(), w, err)
 			return
 		}
 
 		{{end}}l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx)
 		{{if .HasResp}}resp, {{end}}err := l.{{.Call}}({{if .HasRequest}}&req{{end}})
 		if err != nil {
-			httpx.Error(w, err)
+			httpx.ErrorCtx(r.Context(), w, err)
 		} else {
-			{{if .HasResp}}httpx.OkJson(w, resp){{else}}httpx.Ok(w){{end}}
+			{{if .HasResp}}httpx.OkJsonCtx(r.Context(), w, resp){{else}}httpx.Ok(w){{end}}
 		}
 	}
 }