Kevin Wan 3 rokov pred
rodič
commit
0ee7654407

+ 1 - 1
zrpc/internal/client.go

@@ -69,8 +69,8 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption {
 		WithUnaryClientInterceptors(
 			clientinterceptors.TracingInterceptor,
 			clientinterceptors.DurationInterceptor,
-			clientinterceptors.BreakerInterceptor,
 			clientinterceptors.PrometheusInterceptor,
+			clientinterceptors.BreakerInterceptor,
 			clientinterceptors.TimeoutInterceptor(cliOpts.Timeout),
 		),
 	}

+ 13 - 0
zrpc/internal/codes/accept.go

@@ -1,6 +1,8 @@
 package codes
 
 import (
+	"context"
+
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
@@ -10,6 +12,17 @@ func Acceptable(err error) bool {
 	switch status.Code(err) {
 	case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss:
 		return false
+	case codes.Unknown:
+		return acceptableUnknown(err)
+	default:
+		return true
+	}
+}
+
+func acceptableUnknown(err error) bool {
+	switch err {
+	case context.DeadlineExceeded:
+		return false
 	default:
 		return true
 	}

+ 2 - 0
zrpc/internal/rpcserver.go

@@ -58,10 +58,12 @@ func (s *rpcServer) Start(register RegisterFn) error {
 		serverinterceptors.UnaryCrashInterceptor(),
 		serverinterceptors.UnaryStatInterceptor(s.metrics),
 		serverinterceptors.UnaryPrometheusInterceptor(),
+		serverinterceptors.UnaryBreakerInterceptor(),
 	}
 	unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
 	streamInterceptors := []grpc.StreamServerInterceptor{
 		serverinterceptors.StreamCrashInterceptor,
+		serverinterceptors.StreamBreakerInterceptor,
 	}
 	streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
 	options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...),

+ 33 - 0
zrpc/internal/serverinterceptors/breakerinterceptor.go

@@ -0,0 +1,33 @@
+package serverinterceptors
+
+import (
+	"context"
+
+	"github.com/tal-tech/go-zero/core/breaker"
+	"github.com/tal-tech/go-zero/zrpc/internal/codes"
+	"google.golang.org/grpc"
+)
+
+// StreamBreakerInterceptor is an interceptor that acts as a circuit breaker.
+func StreamBreakerInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
+	handler grpc.StreamHandler) (err error) {
+	breakerName := info.FullMethod
+	return breaker.DoWithAcceptable(breakerName, func() error {
+		return handler(srv, stream)
+	}, codes.Acceptable)
+}
+
+// UnaryBreakerInterceptor is an interceptor that acts as a circuit breaker.
+func UnaryBreakerInterceptor() grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
+		handler grpc.UnaryHandler) (resp interface{}, err error) {
+		breakerName := info.FullMethod
+		err = breaker.DoWithAcceptable(breakerName, func() error {
+			var err error
+			resp, err = handler(ctx, req)
+			return err
+		}, codes.Acceptable)
+
+		return resp, err
+	}
+}

+ 31 - 0
zrpc/internal/serverinterceptors/breakerinterceptor_test.go

@@ -0,0 +1,31 @@
+package serverinterceptors
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+func TestStreamBreakerInterceptor(t *testing.T) {
+	err := StreamBreakerInterceptor(nil, nil, &grpc.StreamServerInfo{
+		FullMethod: "any",
+	}, func(
+		srv interface{}, stream grpc.ServerStream) error {
+		return status.New(codes.DeadlineExceeded, "any").Err()
+	})
+	assert.NotNil(t, err)
+}
+
+func TestUnaryBreakerInterceptor(t *testing.T) {
+	interceptor := UnaryBreakerInterceptor()
+	_, err := interceptor(nil, nil, &grpc.UnaryServerInfo{
+		FullMethod: "any",
+	}, func(ctx context.Context, req interface{}) (interface{}, error) {
+		return nil, status.New(codes.DeadlineExceeded, "any").Err()
+	})
+	assert.NotNil(t, err)
+}