
Add grpc retry (#1160)

* Add grpc retry

* Update grpc retry

* Add tests

* Fix a bug

* Add api && some tests

* Add comment

* Add double check

* Add server retry quota

* Update optimize code

* Fix bug

* Update optimize code

* Update optimize code

* Fix bug
chenquan 3 years ago
parent
commit
462ddbb145

+ 31 - 0
core/retry/backoff/backoff.go

@@ -0,0 +1,31 @@
+package backoff
+
+import (
+	"math/rand"
+	"time"
+)
+
+type Func func(attempt int) time.Duration
+
+// LinearWithJitter waits a set period of time, allowing for jitter (fractional adjustment).
+func LinearWithJitter(waitBetween time.Duration, jitterFraction float64) Func {
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	return func(attempt int) time.Duration {
+		multiplier := jitterFraction * (r.Float64()*2 - 1)
+		return time.Duration(float64(waitBetween) * (1 + multiplier))
+	}
+}
+
+// Interval waits a fixed period of time between calls.
+func Interval(interval time.Duration) Func {
+	return func(attempt int) time.Duration {
+		return interval
+	}
+}
+
+// Exponential produces increasing intervals for each attempt.
+func Exponential(scalar time.Duration) Func {
+	return func(attempt int) time.Duration {
+		return scalar * time.Duration((1<<attempt)>>1)
+	}
+}
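
A minimal sketch of how the new backoff helpers behave (the `main` package and printed values are illustrative only): `Exponential` doubles the wait per attempt, `LinearWithJitter` returns the base wait adjusted by a random fraction.

```go
package main

import (
	"fmt"
	"time"

	"github.com/tal-tech/go-zero/core/retry/backoff"
)

func main() {
	// Exponential(1s): attempts 1, 2, 3 wait 1s, 2s, 4s.
	exp := backoff.Exponential(time.Second)
	for attempt := 1; attempt <= 3; attempt++ {
		fmt.Println(exp(attempt))
	}

	// Roughly 100ms, randomly adjusted by up to +/-10% on every attempt.
	jitter := backoff.LinearWithJitter(100*time.Millisecond, 0.1)
	fmt.Println(jitter(1))
}
```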

+ 17 - 0
core/retry/backoff/backoff_test.go

@@ -0,0 +1,17 @@
+package backoff
+
+import (
+	"github.com/stretchr/testify/assert"
+	"testing"
+	"time"
+)
+
+func TestWaitBetween(t *testing.T) {
+	fn := Interval(time.Second)
+	assert.EqualValues(t, time.Second, fn(1))
+}
+
+func TestExponential(t *testing.T) {
+	fn := Exponential(time.Second)
+	assert.EqualValues(t, time.Second, fn(1))
+}

+ 42 - 0
core/retry/options.go

@@ -0,0 +1,42 @@
+package retry
+
+import (
+	"github.com/tal-tech/go-zero/core/retry/backoff"
+	"google.golang.org/grpc/codes"
+	"time"
+)
+
+// WithDisable disables the retry behaviour on this call, or this interceptor.
+//
+// It's semantically the same as calling `WithMax(0)`.
+func WithDisable() *CallOption {
+	return WithMax(0)
+}
+
+// WithMax sets the maximum number of retries on this call, or this interceptor.
+func WithMax(maxRetries int) *CallOption {
+	return &CallOption{apply: func(options *options) {
+		options.max = maxRetries
+	}}
+}
+
+// WithBackoff sets the `BackoffFunc` used to control time between retries.
+func WithBackoff(backoffFunc backoff.Func) *CallOption {
+	return &CallOption{apply: func(o *options) {
+		o.backoffFunc = backoffFunc
+	}}
+}
+
+// WithCodes sets the gRPC codes that should be retried.
+func WithCodes(retryCodes ...codes.Code) *CallOption {
+	return &CallOption{apply: func(o *options) {
+		o.codes = retryCodes
+	}}
+}
+
+// WithPerRetryTimeout sets the timeout for each retry attempt.
+func WithPerRetryTimeout(timeout time.Duration) *CallOption {
+	return &CallOption{apply: func(o *options) {
+		o.perCallTimeout = timeout
+	}}
+}
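
A hedged sketch of how these options are meant to be used per call: each `*CallOption` also satisfies `grpc.CallOption`, so it can ride along on an ordinary invocation and be picked out by the retry interceptor. The sketch assumes `cc` was dialed with the client retry interceptor installed; the method path and the `emptypb.Empty` messages are placeholders.

```go
package example

import (
	"context"
	"time"

	"github.com/tal-tech/go-zero/core/retry"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/protobuf/types/known/emptypb"
)

// callWithRetry passes retry options as ordinary grpc.CallOption values;
// the retry interceptor separates them from the plain gRPC options and
// applies them to this call only.
func callWithRetry(ctx context.Context, cc *grpc.ClientConn) error {
	in, out := new(emptypb.Empty), new(emptypb.Empty)
	return cc.Invoke(ctx, "/pkg.Service/Method", in, out, // placeholder RPC path
		retry.WithMax(3),                                // up to 3 retries after the first call
		retry.WithPerRetryTimeout(200*time.Millisecond), // each retry gets its own deadline
		retry.WithCodes(codes.Unavailable, codes.ResourceExhausted),
	)
}
```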

+ 92 - 0
core/retry/options_test.go

@@ -0,0 +1,92 @@
+package retry
+
+import (
+	"context"
+	"errors"
+	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/logx"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+	"testing"
+	"time"
+)
+
+func TestRetryWithDisable(t *testing.T) {
+	opt := &options{}
+	assert.EqualValues(t, &options{}, parseRetryCallOptions(opt, WithDisable()))
+}
+
+func TestRetryWithMax(t *testing.T) {
+	n := 5
+	for i := 0; i < n; i++ {
+		opt := &options{}
+		assert.EqualValues(t, &options{max: i}, parseRetryCallOptions(opt, WithMax(i)))
+	}
+}
+
+func TestRetryWithBackoff(t *testing.T) {
+	opt := &options{}
+
+	retryCallOptions := parseRetryCallOptions(opt, WithBackoff(func(attempt int) time.Duration {
+		return time.Millisecond
+	}))
+	assert.EqualValues(t, time.Millisecond, retryCallOptions.backoffFunc(1))
+
+}
+
+func TestRetryWithCodes(t *testing.T) {
+	opt := &options{}
+	c := []codes.Code{codes.Unknown, codes.NotFound}
+	options := parseRetryCallOptions(opt, WithCodes(c...))
+	assert.EqualValues(t, c, options.codes)
+}
+
+func TestRetryWithPerRetryTimeout(t *testing.T) {
+	opt := &options{}
+	options := parseRetryCallOptions(opt, WithPerRetryTimeout(time.Millisecond))
+	assert.EqualValues(t, time.Millisecond, options.perCallTimeout)
+}
+
+func Test_waitRetryBackoff(t *testing.T) {
+
+	opt := &options{perCallTimeout: time.Second, backoffFunc: func(attempt int) time.Duration {
+		return time.Second
+	}}
+	logger := logx.WithContext(context.Background())
+	err := waitRetryBackoff(logger, 1, context.Background(), opt)
+	assert.NoError(t, err)
+	ctx, cancelFunc := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancelFunc()
+	err = waitRetryBackoff(logger, 1, ctx, opt)
+	assert.ErrorIs(t, err, status.FromContextError(context.DeadlineExceeded).Err())
+}
+
+func Test_isRetriable(t *testing.T) {
+	assert.False(t, isRetriable(status.FromContextError(context.DeadlineExceeded).Err(), &options{codes: DefaultRetriableCodes}))
+	assert.True(t, isRetriable(status.Error(codes.ResourceExhausted, ""), &options{codes: DefaultRetriableCodes}))
+	assert.False(t, isRetriable(errors.New("error"), &options{}))
+}
+
+func Test_perCallContext(t *testing.T) {
+	opt := &options{perCallTimeout: time.Second, includeRetryHeader: true}
+	ctx := metadata.NewIncomingContext(context.Background(), map[string][]string{"1": {"1"}})
+	callContext := perCallContext(ctx, opt, 1)
+	md, ok := metadata.FromOutgoingContext(callContext)
+	assert.True(t, ok)
+	assert.EqualValues(t, metadata.MD{"1": {"1"}, AttemptMetadataKey: {"1"}}, md)
+
+}
+
+func Test_filterCallOptions(t *testing.T) {
+	grpcEmptyCallOpt := &grpc.EmptyCallOption{}
+	retryCallOpt := &CallOption{}
+	options, retryCallOptions := filterCallOptions([]grpc.CallOption{
+		grpcEmptyCallOpt,
+		retryCallOpt,
+	})
+	assert.EqualValues(t, []grpc.CallOption{grpcEmptyCallOpt}, options)
+	assert.EqualValues(t, []*CallOption{retryCallOpt}, retryCallOptions)
+
+}

+ 179 - 0
core/retry/retryinterceptor.go

@@ -0,0 +1,179 @@
+package retry
+
+import (
+	"context"
+	"github.com/tal-tech/go-zero/core/logx"
+	"github.com/tal-tech/go-zero/core/retry/backoff"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+
+	"strconv"
+	"time"
+)
+
+const AttemptMetadataKey = "x-retry-attempt"
+
+var (
+	// DefaultRetriableCodes is the default set of codes considered retriable.
+	DefaultRetriableCodes = []codes.Code{codes.ResourceExhausted, codes.Unavailable}
+	// defaultRetryOptions is the default retry configuration.
+	defaultRetryOptions = &options{
+		max:                0, // disabled
+		perCallTimeout:     0, // disabled
+		includeRetryHeader: true,
+		codes:              DefaultRetriableCodes,
+		backoffFunc:        backoff.LinearWithJitter(50*time.Millisecond /*jitter*/, 0.10),
+	}
+)
+
+type (
+	// options is the retry configuration.
+	options struct {
+		max                int
+		perCallTimeout     time.Duration
+		includeRetryHeader bool
+		codes              []codes.Code
+		backoffFunc        backoff.Func
+	}
+	// CallOption is a grpc.CallOption that is local to grpc retry.
+	CallOption struct {
+		grpc.EmptyCallOption // make sure we implement the private after() and before() methods so we don't panic.
+		apply                func(opt *options)
+	}
+)
+
+func waitRetryBackoff(logger logx.Logger, attempt int, ctx context.Context, retryOptions *options) error {
+	var waitTime time.Duration = 0
+	if attempt > 0 {
+		waitTime = retryOptions.backoffFunc(attempt)
+	}
+	if waitTime > 0 {
+		timer := time.NewTimer(waitTime)
+		logger.Infof("grpc retry attempt: %d, backoff for %v", attempt, waitTime)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return status.FromContextError(ctx.Err()).Err()
+		case <-timer.C:
+			// double check
+			err := ctx.Err()
+			if err != nil {
+				return status.FromContextError(err).Err()
+			}
+		}
+	}
+	return nil
+}
+
+func isRetriable(err error, retryOptions *options) bool {
+	errCode := status.Code(err)
+	if isContextError(err) {
+		return false
+	}
+	for _, code := range retryOptions.codes {
+		if code == errCode {
+			return true
+		}
+	}
+	return false
+}
+
+func isContextError(err error) bool {
+	code := status.Code(err)
+	return code == codes.DeadlineExceeded || code == codes.Canceled
+}
+
+func reuseOrNewWithCallOptions(opt *options, retryCallOptions []*CallOption) *options {
+	if len(retryCallOptions) == 0 {
+		return opt
+	}
+	// copy the defaults first so per-call options never mutate the shared defaultRetryOptions
+	optCopy := &options{}
+	*optCopy = *opt
+	return parseRetryCallOptions(optCopy, retryCallOptions...)
+}
+
+func parseRetryCallOptions(opt *options, opts ...*CallOption) *options {
+	for _, option := range opts {
+		option.apply(opt)
+	}
+	return opt
+}
+
+func perCallContext(ctx context.Context, callOpts *options, attempt int) context.Context {
+	if attempt > 0 {
+		if callOpts.perCallTimeout != 0 {
+			var cancel context.CancelFunc
+			ctx, cancel = context.WithTimeout(ctx, callOpts.perCallTimeout)
+			// the derived context is returned to the caller, so cancel cannot be
+			// deferred here; the timer is released when the per-call timeout fires
+			_ = cancel
+		}
+		if callOpts.includeRetryHeader {
+			cloneMd := extractIncomingAndClone(ctx)
+			cloneMd.Set(AttemptMetadataKey, strconv.Itoa(attempt))
+			ctx = metadata.NewOutgoingContext(ctx, cloneMd)
+		}
+	}
+
+	return ctx
+}
+
+func extractIncomingAndClone(ctx context.Context) metadata.MD {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		return metadata.MD{}
+	}
+	// clone
+	return md.Copy()
+}
+
+func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOption, retryOptions []*CallOption) {
+	for _, opt := range callOptions {
+		if co, ok := opt.(*CallOption); ok {
+			retryOptions = append(retryOptions, co)
+		} else {
+			grpcOptions = append(grpcOptions, opt)
+		}
+	}
+	return grpcOptions, retryOptions
+}
+
+func Do(ctx context.Context, call func(ctx context.Context, opts ...grpc.CallOption) error, opts ...grpc.CallOption) error {
+	logger := logx.WithContext(ctx)
+	grpcOpts, retryOpts := filterCallOptions(opts)
+	callOpts := reuseOrNewWithCallOptions(defaultRetryOptions, retryOpts)
+
+	if callOpts.max == 0 {
+		return call(ctx, opts...)
+	}
+	var lastErr error
+	for attempt := 0; attempt <= callOpts.max; attempt++ {
+		if err := waitRetryBackoff(logger, attempt, ctx, callOpts); err != nil {
+			return err
+		}
+
+		callCtx := perCallContext(ctx, callOpts, attempt)
+		lastErr = call(callCtx, grpcOpts...)
+
+		if lastErr == nil {
+			return nil
+		}
+		if attempt == 0 {
+			logger.Errorf("grpc call failed, got err: %v", lastErr)
+		} else {
+			logger.Errorf("grpc retry attempt: %d, got err: %v", attempt, lastErr)
+		}
+		if isContextError(lastErr) {
+			if ctx.Err() != nil {
+				logger.Errorf("grpc retry attempt: %d, parent context error: %v", attempt, ctx.Err())
+				return lastErr
+			} else if callOpts.perCallTimeout != 0 {
+				logger.Errorf("grpc retry attempt: %d, context error from retry call", attempt)
+				continue
+			}
+		}
+		if !isRetriable(lastErr, callOpts) {
+			return lastErr
+		}
+	}
+	return lastErr
+}
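
A small sketch of `Do`'s counting semantics, mirroring `TestDo` below: attempt 0 is the original call, so `WithMax(2)` allows at most three invocations before the last error is returned. The flaky call here is a stand-in for a real gRPC invocation.

```go
package example

import (
	"context"
	"fmt"

	"github.com/tal-tech/go-zero/core/retry"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// countAttempts demonstrates that WithMax(2) means "1 original call + 2 retries".
func countAttempts() {
	calls := 0
	err := retry.Do(context.Background(), func(ctx context.Context, _ ...grpc.CallOption) error {
		calls++
		return status.Error(codes.Unavailable, "flaky backend")
	}, retry.WithMax(2))
	fmt.Println(calls, status.Code(err)) // 3 Unavailable
}
```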

+ 25 - 0
core/retry/retryinterceptor_test.go

@@ -0,0 +1,25 @@
+package retry
+
+import (
+	"context"
+	"github.com/stretchr/testify/assert"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"testing"
+)
+
+func TestDo(t *testing.T) {
+	n := 4
+	for i := 0; i < n; i++ {
+		count := 0
+		err := Do(context.Background(), func(ctx context.Context, opts ...grpc.CallOption) error {
+			count++
+			return status.Error(codes.ResourceExhausted, "ResourceExhausted")
+
+		}, WithMax(i))
+		assert.Error(t, err)
+		assert.Equal(t, i+1, count)
+	}
+
+}

+ 5 - 0
zrpc/client.go

@@ -14,6 +14,8 @@ var (
 	WithDialOption = internal.WithDialOption
 	// WithTimeout is an alias of internal.WithTimeout.
 	WithTimeout = internal.WithTimeout
+	// WithRetry is an alias of internal.WithRetry.
+	WithRetry = internal.WithRetry
 	// WithUnaryClientInterceptor is an alias of internal.WithUnaryClientInterceptor.
 	WithUnaryClientInterceptor = internal.WithUnaryClientInterceptor
 )
@@ -52,6 +54,9 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
 	if c.Timeout > 0 {
 		opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond))
 	}
+	if c.Retry {
+		opts = append(opts, WithRetry())
+	}
 	opts = append(opts, options...)
 
 	var target string
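
A hedged end-to-end sketch on the client side: `zrpc.WithRetry()` (or `Retry: true` in the config) installs the retry interceptor, but since the default max is 0, each call still opts in with `retry.WithMax`. The target and method path are placeholders.

```go
package example

import (
	"context"

	"github.com/tal-tech/go-zero/core/retry"
	"github.com/tal-tech/go-zero/zrpc"
	"google.golang.org/protobuf/types/known/emptypb"
)

// pingWithRetry builds a client with retry enabled and retries one call up to twice.
func pingWithRetry(ctx context.Context) error {
	client := zrpc.MustNewClient(
		zrpc.RpcClientConf{Target: "127.0.0.1:8080", Timeout: 2000}, // placeholder target
		zrpc.WithRetry(),
	)

	// The interceptor only retries calls that opt in, because the default max is 0.
	in, out := new(emptypb.Empty), new(emptypb.Empty)
	return client.Conn().Invoke(ctx, "/pkg.Service/Ping", in, out, retry.WithMax(2))
}
```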

+ 2 - 0
zrpc/config.go

@@ -18,6 +18,7 @@ type (
 		// setting 0 means no timeout
 		Timeout      int64 `json:",default=2000"`
 		CpuThreshold int64 `json:",default=900,range=[0:1000]"`
+		MaxRetries   int   `json:",range=[0:]"`
 	}
 
 	// A RpcClientConf is a rpc client config.
@@ -27,6 +28,7 @@ type (
 		Target    string          `json:",optional"`
 		App       string          `json:",optional"`
 		Token     string          `json:",optional"`
+		Retry     bool            `json:",optional"` // grpc auto retry
 		Timeout   int64           `json:",default=2000"`
 	}
 )
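
Hypothetical config literals for the two new fields; in a real service they would come from the YAML/JSON file that `conf.MustLoad` parses into these structs, and the addresses are placeholders.

```go
package example

import "github.com/tal-tech/go-zero/zrpc"

var (
	// Retry: true makes NewClient append WithRetry() automatically.
	clientConf = zrpc.RpcClientConf{
		Target:  "dns:///greeter.example.com:8080",
		Timeout: 2000,
		Retry:   true,
	}

	// MaxRetries is the server-side quota: requests whose x-retry-attempt
	// metadata exceeds it are rejected with FailedPrecondition.
	serverConf = zrpc.RpcServerConf{
		ListenOn:   "0.0.0.0:8080",
		MaxRetries: 2,
	}
)
```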

+ 9 - 0
zrpc/internal/client.go

@@ -31,6 +31,7 @@ type (
 	// A ClientOptions is a client options.
 	ClientOptions struct {
 		Timeout     time.Duration
+		Retry       bool
 		DialOptions []grpc.DialOption
 	}
 
@@ -72,6 +73,7 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption {
 			clientinterceptors.PrometheusInterceptor,
 			clientinterceptors.BreakerInterceptor,
 			clientinterceptors.TimeoutInterceptor(cliOpts.Timeout),
+			clientinterceptors.RetryInterceptor(cliOpts.Retry),
 		),
 		WithStreamClientInterceptors(
 			clientinterceptors.StreamTracingInterceptor,
@@ -117,6 +119,13 @@ func WithTimeout(timeout time.Duration) ClientOption {
 	}
 }
 
+// WithRetry returns a func to customize a ClientOptions with auto retry.
+func WithRetry() ClientOption {
+	return func(options *ClientOptions) {
+		options.Retry = true
+	}
+}
+
 // WithUnaryClientInterceptor returns a func to customize a ClientOptions with given interceptor.
 func WithUnaryClientInterceptor(interceptor grpc.UnaryClientInterceptor) ClientOption {
 	return func(options *ClientOptions) {

+ 19 - 0
zrpc/internal/clientinterceptors/retryinterceptor.go

@@ -0,0 +1,19 @@
+package clientinterceptors
+
+import (
+	"context"
+	"github.com/tal-tech/go-zero/core/retry"
+	"google.golang.org/grpc"
+)
+
+// RetryInterceptor returns a grpc.UnaryClientInterceptor that wraps unary calls with retry.Do when enable is true.
+func RetryInterceptor(enable bool) grpc.UnaryClientInterceptor {
+	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+		if !enable {
+			return invoker(ctx, method, req, reply, cc, opts...)
+		}
+		return retry.Do(ctx, func(ctx context.Context, callOpts ...grpc.CallOption) error {
+			return invoker(ctx, method, req, reply, cc, callOpts...)
+		}, opts...)
+	}
+}

+ 27 - 0
zrpc/internal/clientinterceptors/retryinterceptor_test.go

@@ -0,0 +1,27 @@
+package clientinterceptors
+
+import (
+	"context"
+	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/retry"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"testing"
+)
+
+func TestRetryInterceptor_WithMax(t *testing.T) {
+	n := 4
+	for i := 0; i < n; i++ {
+		count := 0
+		cc := new(grpc.ClientConn)
+		err := RetryInterceptor(true)(context.Background(), "/1", nil, nil, cc,
+			func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
+				count++
+				return status.Error(codes.ResourceExhausted, "ResourceExhausted")
+			}, retry.WithMax(i))
+		assert.Error(t, err)
+		assert.Equal(t, i+1, count)
+	}
+
+}

+ 11 - 2
zrpc/internal/rpcserver.go

@@ -14,7 +14,8 @@ type (
 	ServerOption func(options *rpcServerOptions)
 
 	rpcServerOptions struct {
-		metrics *stat.Metrics
+		metrics    *stat.Metrics
+		MaxRetries int
 	}
 
 	rpcServer struct {
@@ -38,7 +39,7 @@ func NewRpcServer(address string, opts ...ServerOption) Server {
 	}
 
 	return &rpcServer{
-		baseRpcServer: newBaseRpcServer(address, options.metrics),
+		baseRpcServer: newBaseRpcServer(address, &options),
 	}
 }
 
@@ -55,6 +56,7 @@ func (s *rpcServer) Start(register RegisterFn) error {
 
 	unaryInterceptors := []grpc.UnaryServerInterceptor{
 		serverinterceptors.UnaryTracingInterceptor,
+		serverinterceptors.RetryInterceptor(s.maxRetries),
 		serverinterceptors.UnaryCrashInterceptor,
 		serverinterceptors.UnaryStatInterceptor(s.metrics),
 		serverinterceptors.UnaryPrometheusInterceptor,
@@ -87,3 +89,10 @@ func WithMetrics(metrics *stat.Metrics) ServerOption {
 		options.metrics = metrics
 	}
 }
+
+// WithMaxRetries returns a func that sets a max retries to a Server.
+func WithMaxRetries(maxRetries int) ServerOption {
+	return func(options *rpcServerOptions) {
+		options.MaxRetries = maxRetries
+	}
+}

+ 5 - 3
zrpc/internal/server.go

@@ -21,16 +21,18 @@ type (
 	baseRpcServer struct {
 		address            string
 		metrics            *stat.Metrics
+		maxRetries         int
 		options            []grpc.ServerOption
 		streamInterceptors []grpc.StreamServerInterceptor
 		unaryInterceptors  []grpc.UnaryServerInterceptor
 	}
 )
 
-func newBaseRpcServer(address string, metrics *stat.Metrics) *baseRpcServer {
+func newBaseRpcServer(address string, rpcServerOpts *rpcServerOptions) *baseRpcServer {
 	return &baseRpcServer{
-		address: address,
-		metrics: metrics,
+		address:    address,
+		metrics:    rpcServerOpts.metrics,
+		maxRetries: rpcServerOpts.MaxRetries,
 	}
 }
 

+ 3 - 3
zrpc/internal/server_test.go

@@ -11,7 +11,7 @@ import (
 
 func TestBaseRpcServer_AddOptions(t *testing.T) {
 	metrics := stat.NewMetrics("foo")
-	server := newBaseRpcServer("foo", metrics)
+	server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
 	server.SetName("bar")
 	var opt grpc.EmptyServerOption
 	server.AddOptions(opt)
@@ -20,7 +20,7 @@ func TestBaseRpcServer_AddOptions(t *testing.T) {
 
 func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
 	metrics := stat.NewMetrics("foo")
-	server := newBaseRpcServer("foo", metrics)
+	server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
 	server.SetName("bar")
 	var vals []int
 	f := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
@@ -36,7 +36,7 @@ func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
 
 func TestBaseRpcServer_AddUnaryInterceptors(t *testing.T) {
 	metrics := stat.NewMetrics("foo")
-	server := newBaseRpcServer("foo", metrics)
+	server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
 	server.SetName("bar")
 	var vals []int
 	f := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (

+ 33 - 0
zrpc/internal/serverinterceptors/retryinterceptor.go

@@ -0,0 +1,33 @@
+package serverinterceptors
+
+import (
+	"context"
+	"github.com/tal-tech/go-zero/core/logx"
+	"github.com/tal-tech/go-zero/core/retry"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+	"strconv"
+)
+
+// RetryInterceptor is the server-side retry quota: it rejects requests whose
+// x-retry-attempt metadata reports an attempt number greater than maxAttempt.
+func RetryInterceptor(maxAttempt int) grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+		var md metadata.MD
+		requestMd, ok := metadata.FromIncomingContext(ctx)
+		if ok {
+			md = requestMd.Copy()
+			attemptMd := md.Get(retry.AttemptMetadataKey)
+			if len(attemptMd) != 0 && attemptMd[0] != "" {
+				if attempt, err := strconv.Atoi(attemptMd[0]); err == nil {
+					if attempt > maxAttempt {
+						logx.WithContext(ctx).Errorf("retries exceeded:%d, max retries:%d", attempt, maxAttempt)
+						return nil, status.Error(codes.FailedPrecondition, "Retries exceeded")
+					}
+				}
+			}
+		}
+
+		return handler(ctx, req)
+	}
+}
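
For reference, a short sketch of how any handler or interceptor could read the same `x-retry-attempt` header that this quota check inspects; the helper name is illustrative.

```go
package example

import (
	"context"

	"github.com/tal-tech/go-zero/core/retry"
	"google.golang.org/grpc/metadata"
)

// attemptFromContext returns the attempt number stamped into the outgoing
// metadata by perCallContext: "1", "2", ... for retried calls, "" on the
// first attempt (the header is only set when attempt > 0).
func attemptFromContext(ctx context.Context) string {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ""
	}
	if vals := md.Get(retry.AttemptMetadataKey); len(vals) > 0 {
		return vals[0]
	}
	return ""
}
```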

+ 40 - 0
zrpc/internal/serverinterceptors/retryinterceptor_test.go

@@ -0,0 +1,40 @@
+package serverinterceptors
+
+import (
+	"context"
+	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/retry"
+	"google.golang.org/grpc/metadata"
+	"testing"
+)
+
+func TestRetryInterceptor(t *testing.T) {
+	t.Run("retries exceeded", func(t *testing.T) {
+		interceptor := RetryInterceptor(2)
+		ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{retry.AttemptMetadataKey: "3"}))
+		resp, err := interceptor(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
+			return nil, nil
+		})
+		assert.Error(t, err)
+		assert.Nil(t, resp)
+	})
+
+	t.Run("reasonable retries", func(t *testing.T) {
+		interceptor := RetryInterceptor(2)
+		ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{retry.AttemptMetadataKey: "2"}))
+		resp, err := interceptor(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
+			return nil, nil
+		})
+		assert.NoError(t, err)
+		assert.Nil(t, resp)
+	})
+	t.Run("no retries", func(t *testing.T) {
+		interceptor := RetryInterceptor(0)
+		resp, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
+			return nil, nil
+		})
+		assert.NoError(t, err)
+		assert.Nil(t, resp)
+	})
+
+}

+ 4 - 2
zrpc/server.go

@@ -38,13 +38,15 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error
 
 	var server internal.Server
 	metrics := stat.NewMetrics(c.ListenOn)
+	serverOptions := []internal.ServerOption{internal.WithMetrics(metrics), internal.WithMaxRetries(c.MaxRetries)}
+
 	if c.HasEtcd() {
-		server, err = internal.NewRpcPubServer(c.Etcd.Hosts, c.Etcd.Key, c.ListenOn, internal.WithMetrics(metrics))
+		server, err = internal.NewRpcPubServer(c.Etcd.Hosts, c.Etcd.Key, c.ListenOn, serverOptions...)
 		if err != nil {
 			return nil, err
 		}
 	} else {
-		server = internal.NewRpcServer(c.ListenOn, internal.WithMetrics(metrics))
+		server = internal.NewRpcServer(c.ListenOn, serverOptions...)
 	}
 
 	server.SetName(c.Name)