Selaa lähdekoodia

Fix context error in grpc (#962)

* Fix context error in rpc

* Add a test case

* Optimize judgment conditions

* Add customized breaker errors for the client and server

* Update method signature

* Delete customized breaker errors

* Delete the wrong test case
chenquan 3 vuotta sitten
vanhempi
sitoutus
dfb3cb510a

+ 0 - 13
zrpc/internal/codes/accept.go

@@ -1,8 +1,6 @@
 package codes
 
 import (
-	"context"
-
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
@@ -12,17 +10,6 @@ func Acceptable(err error) bool {
 	switch status.Code(err) {
 	case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss:
 		return false
-	case codes.Unknown:
-		return acceptableUnknown(err)
-	default:
-		return true
-	}
-}
-
-func acceptableUnknown(err error) bool {
-	switch err {
-	case context.DeadlineExceeded:
-		return false
 	default:
 		return true
 	}

+ 1 - 0
zrpc/internal/codes/accept_test.go

@@ -9,6 +9,7 @@ import (
 )
 
 func TestAccept(t *testing.T) {
+
 	tests := []struct {
 		name   string
 		err    error

+ 10 - 1
zrpc/internal/serverinterceptors/timeoutinterceptor.go

@@ -3,6 +3,8 @@ package serverinterceptors
 import (
 	"context"
 	"fmt"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 	"runtime/debug"
 	"strings"
 	"sync"
@@ -46,7 +48,14 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor
 			defer lock.Unlock()
 			return resp, err
 		case <-ctx.Done():
-			return nil, ctx.Err()
+			err := ctx.Err()
+
+			if err == context.Canceled {
+				err = status.Error(codes.Canceled, err.Error())
+			} else if err == context.DeadlineExceeded {
+				err = status.Error(codes.DeadlineExceeded, err.Error())
+			}
+			return nil, err
 		}
 	}
 }

+ 22 - 1
zrpc/internal/serverinterceptors/timeoutinterceptor_test.go

@@ -2,6 +2,8 @@ package serverinterceptors
 
 import (
 	"context"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 	"sync"
 	"testing"
 	"time"
@@ -66,5 +68,24 @@ func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
 		return nil, nil
 	})
 	wg.Wait()
-	assert.Equal(t, context.DeadlineExceeded, err)
+	assert.EqualValues(t, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), err)
+}
+func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
+	const timeout = time.Minute * 10
+	interceptor := UnaryTimeoutInterceptor(timeout)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
+		FullMethod: "/",
+	}, func(ctx context.Context, req interface{}) (interface{}, error) {
+		defer wg.Done()
+		time.Sleep(time.Millisecond * 50)
+		return nil, nil
+	})
+
+	wg.Wait()
+	assert.EqualValues(t, status.Error(codes.Canceled, context.Canceled.Error()), err)
 }