Răsfoiți Sursa

simplify timeoutinterceptor (#840)

Co-authored-by: chenmusheng <chenmusheng@laoyuegou.com>
masonchen2014 3 ani în urmă
părinte
comite
cb8d9d413a

+ 9 - 1
zrpc/client_test.go

@@ -6,6 +6,7 @@ import (
 	"log"
 	"net"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/core/logx"
@@ -58,6 +59,13 @@ func TestDepositServer_Deposit(t *testing.T) {
 			codes.OK,
 			"",
 		},
+		{
+			"valid request with long handling time",
+			2000.00,
+			nil,
+			codes.DeadlineExceeded,
+			fmt.Sprintf("context deadline exceeded"),
+		},
 	}
 
 	directClient := MustNewClient(
@@ -79,7 +87,7 @@ func TestDepositServer_Deposit(t *testing.T) {
 			func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
 				invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
 				return invoker(ctx, method, req, reply, cc, opts...)
-			}))
+			}), WithTimeout(1000*time.Millisecond))
 	assert.Nil(t, err)
 	clients := []Client{
 		directClient,

+ 1 - 22
zrpc/internal/clientinterceptors/timeoutinterceptor.go

@@ -17,27 +17,6 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
 
 		ctx, cancel := context.WithTimeout(ctx, timeout)
 		defer cancel()
-
-		// create channel with buffer size 1 to avoid goroutine leak
-		done := make(chan error, 1)
-		panicChan := make(chan interface{}, 1)
-		go func() {
-			defer func() {
-				if p := recover(); p != nil {
-					panicChan <- p
-				}
-			}()
-
-			done <- invoker(ctx, method, req, reply, cc, opts...)
-		}()
-
-		select {
-		case p := <-panicChan:
-			panic(p)
-		case err := <-done:
-			return err
-		case <-ctx.Done():
-			return ctx.Err()
-		}
+		return invoker(ctx, method, req, reply, cc, opts...)
 	}
 }

+ 0 - 19
zrpc/internal/clientinterceptors/timeoutinterceptor_test.go

@@ -49,25 +49,6 @@ func TestTimeoutInterceptor_timeout(t *testing.T) {
 	assert.Nil(t, err)
 }
 
-func TestTimeoutInterceptor_timeoutExpire(t *testing.T) {
-	const timeout = time.Millisecond * 10
-	interceptor := TimeoutInterceptor(timeout)
-	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
-	defer cancel()
-	var wg sync.WaitGroup
-	wg.Add(1)
-	cc := new(grpc.ClientConn)
-	err := interceptor(ctx, "/foo", nil, nil, cc,
-		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
-			opts ...grpc.CallOption) error {
-			defer wg.Done()
-			time.Sleep(time.Millisecond * 50)
-			return nil
-		})
-	wg.Wait()
-	assert.Equal(t, context.DeadlineExceeded, err)
-}
-
 func TestTimeoutInterceptor_panic(t *testing.T) {
 	timeouts := []time.Duration{0, time.Millisecond * 10}
 	for _, timeout := range timeouts {

+ 2 - 0
zrpc/internal/mock/depositserver.go

@@ -2,6 +2,7 @@ package mock
 
 import (
 	"context"
+	"time"
 
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -16,5 +17,6 @@ func (*DepositServer) Deposit(ctx context.Context, req *DepositRequest) (*Deposi
 		return nil, status.Errorf(codes.InvalidArgument, "cannot deposit %v", req.GetAmount())
 	}
 
+	time.Sleep(time.Duration(req.GetAmount()) * time.Millisecond)
 	return &DepositResponse{Ok: true}, nil
 }