Sfoglia il codice sorgente

update: expand the retry method to support timeout and interval control (#3283)

Xiaoju Jiang 1 anno fa
parent
commit
8d48e34eed
2 ha cambiato i file con 160 aggiunte e 16 eliminazioni
  1. 77 7
      core/fx/retry.go
  2. 83 9
      core/fx/retry_test.go

+ 77 - 7
core/fx/retry.go

@@ -1,31 +1,87 @@
 package fx
 
-import "github.com/zeromicro/go-zero/core/errorx"
+import (
+	"context"
+	"errors"
+	"time"
+
+	"github.com/zeromicro/go-zero/core/errorx"
+)
 
 const defaultRetryTimes = 3
 
+var errTimeout = errors.New("retry timeout")
+
 type (
 	// RetryOption defines the method to customize DoWithRetry.
 	RetryOption func(*retryOptions)
 
 	retryOptions struct {
-		times int
+		times    int
+		interval time.Duration
+		timeout  time.Duration
 	}
 )
 
 // DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
+// Note that if the fn function accesses global variables outside the function and performs modification operations,
+// it is best to lock them, otherwise there may be data race issues
 func DoWithRetry(fn func() error, opts ...RetryOption) error {
+	return retry(fn, opts...)
+}
+
+// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
+// fn retryCount indicates the current number of retries,starting from 0
+// Note that if the fn function accesses global variables outside the function and performs modification operations,
+// it is best to lock them, otherwise there may be data race issues
+func DoWithRetryCtx(fn func(ctx context.Context, retryCount int) error, opts ...RetryOption) error {
+	return retry(fn, opts...)
+}
+
+func retry(fn interface{}, opts ...RetryOption) error {
 	options := newRetryOptions()
 	for _, opt := range opts {
 		opt(options)
 	}
 
+	sign := make(chan error, 1)
 	var berr errorx.BatchError
+	var cancelFunc context.CancelFunc
+	ctx := context.Background()
+	if options.timeout > 0 {
+		ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
+		defer cancelFunc()
+	}
+
 	for i := 0; i < options.times; i++ {
-		if err := fn(); err != nil {
-			berr.Add(err)
-		} else {
-			return nil
+		go func(retryCount int) {
+			switch f := fn.(type) {
+			case func() error:
+				sign <- f()
+			case func(ctx context.Context, retryCount int) error:
+				sign <- f(ctx, retryCount)
+			}
+		}(i)
+
+		select {
+		case err := <-sign:
+			if err != nil {
+				berr.Add(err)
+			} else {
+				return nil
+			}
+		case <-ctx.Done():
+			berr.Add(errTimeout)
+			return berr.Err()
+		}
+
+		if options.interval > 0 {
+			select {
+			case <-ctx.Done():
+				berr.Add(errTimeout)
+				return berr.Err()
+			case <-time.After(options.interval):
+			}
 		}
 	}
 
@@ -39,8 +95,22 @@ func WithRetry(times int) RetryOption {
 	}
 }
 
+func WithInterval(interval time.Duration) RetryOption {
+	return func(options *retryOptions) {
+		options.interval = interval
+	}
+}
+
+func WithTimeout(timeout time.Duration) RetryOption {
+	return func(options *retryOptions) {
+		options.timeout = timeout
+	}
+}
+
 func newRetryOptions() *retryOptions {
 	return &retryOptions{
-		times: defaultRetryTimes,
+		times:    defaultRetryTimes,
+		interval: 0,
+		timeout:  0,
 	}
 }

+ 83 - 9
core/fx/retry_test.go

@@ -1,8 +1,10 @@
 package fx
 
 import (
+	"context"
 	"errors"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
@@ -12,31 +14,103 @@ func TestRetry(t *testing.T) {
 		return errors.New("any")
 	}))
 
-	var times int
+	times1 := 0
 	assert.Nil(t, DoWithRetry(func() error {
-		times++
-		if times == defaultRetryTimes {
+		times1++
+		if times1 == defaultRetryTimes {
 			return nil
 		}
 		return errors.New("any")
 	}))
 
-	times = 0
+	times2 := 0
 	assert.NotNil(t, DoWithRetry(func() error {
-		times++
-		if times == defaultRetryTimes+1 {
+		times2++
+		if times2 == defaultRetryTimes+1 {
 			return nil
 		}
 		return errors.New("any")
 	}))
 
 	total := 2 * defaultRetryTimes
-	times = 0
+	times3 := 0
 	assert.Nil(t, DoWithRetry(func() error {
-		times++
-		if times == total {
+		times3++
+		if times3 == total {
 			return nil
 		}
 		return errors.New("any")
 	}, WithRetry(total)))
 }
+
+func TestRetryWithTimeout(t *testing.T) {
+	assert.Nil(t, DoWithRetry(func() error {
+		return nil
+	}, WithTimeout(time.Second*10)))
+
+	times1 := 0
+	assert.Nil(t, DoWithRetry(func() error {
+		times1++
+		if times1 == 1 {
+			return errors.New("any ")
+		}
+		time.Sleep(time.Second * 3)
+		return nil
+	}, WithTimeout(time.Second*5)))
+
+	total := defaultRetryTimes
+	times2 := 0
+	assert.Nil(t, DoWithRetry(func() error {
+		times2++
+		if times2 == total {
+			return nil
+		}
+		time.Sleep(time.Second)
+		return errors.New("any")
+	}, WithTimeout(time.Second*(time.Duration(total)+2))))
+
+	assert.NotNil(t, DoWithRetry(func() error {
+		return errors.New("any")
+	}, WithTimeout(time.Second*5)))
+}
+
+func TestRetryWithInterval(t *testing.T) {
+	times1 := 0
+	assert.NotNil(t, DoWithRetry(func() error {
+		times1++
+		if times1 == 1 {
+			return errors.New("any")
+		}
+		time.Sleep(time.Second * 3)
+		return nil
+	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+
+	times2 := 0
+	assert.NotNil(t, DoWithRetry(func() error {
+		times2++
+		if times2 == 2 {
+			return nil
+		}
+		time.Sleep(time.Second * 3)
+		return errors.New("any ")
+	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+
+}
+
+func TestRetryCtx(t *testing.T) {
+	assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error {
+		if retryCount == 0 {
+			return errors.New("any")
+		}
+		time.Sleep(time.Second * 3)
+		return nil
+	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+
+	assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error {
+		if retryCount == 1 {
+			return nil
+		}
+		time.Sleep(time.Second * 3)
+		return errors.New("any ")
+	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+}