Kevin Wan пре 1 година
родитељ
комит
2f2ddd373b
2 измењених фајлова са 36 додато и 38 уклоњено
  1. 20 22
      core/fx/retry.go
  2. 16 16
      core/fx/retry_test.go

+ 20 - 22
core/fx/retry.go

@@ -24,27 +24,33 @@ type (
 )
 
 // DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
-// Note that if the fn function accesses global variables outside the function and performs modification operations,
-// it is best to lock them, otherwise there may be data race issues
+// Note that if the fn function accesses global variables outside the function
+// and performs modification operations, it is best to lock them,
+// otherwise there may be data race issues
 func DoWithRetry(fn func() error, opts ...RetryOption) error {
-	return retry(fn, opts...)
+	return retry(func(errChan chan error, retryCount int) {
+		errChan <- fn()
+	}, opts...)
 }
 
 // DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
-// fn retryCount indicates the current number of retries,starting from 0
-// Note that if the fn function accesses global variables outside the function and performs modification operations,
-// it is best to lock them, otherwise there may be data race issues
-func DoWithRetryCtx(fn func(ctx context.Context, retryCount int) error, opts ...RetryOption) error {
-	return retry(fn, opts...)
+// fn retryCount indicates the current number of retries, starting from 0
+// Note that if the fn function accesses global variables outside the function
+// and performs modification operations, it is best to lock them,
+// otherwise there may be data race issues
+func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
+	opts ...RetryOption) error {
+	return retry(func(errChan chan error, retryCount int) {
+		errChan <- fn(ctx, retryCount)
+	}, opts...)
 }
 
-func retry(fn interface{}, opts ...RetryOption) error {
+func retry(fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
 	options := newRetryOptions()
 	for _, opt := range opts {
 		opt(options)
 	}
 
-	sign := make(chan error, 1)
 	var berr errorx.BatchError
 	var cancelFunc context.CancelFunc
 	ctx := context.Background()
@@ -53,18 +59,12 @@ func retry(fn interface{}, opts ...RetryOption) error {
 		defer cancelFunc()
 	}
 
+	errChan := make(chan error, 1)
 	for i := 0; i < options.times; i++ {
-		go func(retryCount int) {
-			switch f := fn.(type) {
-			case func() error:
-				sign <- f()
-			case func(ctx context.Context, retryCount int) error:
-				sign <- f(ctx, retryCount)
-			}
-		}(i)
+		go fn(errChan, i)
 
 		select {
-		case err := <-sign:
+		case err := <-errChan:
 			if err != nil {
 				berr.Add(err)
 			} else {
@@ -109,8 +109,6 @@ func WithTimeout(timeout time.Duration) RetryOption {
 
 func newRetryOptions() *retryOptions {
 	return &retryOptions{
-		times:    defaultRetryTimes,
-		interval: 0,
-		timeout:  0,
+		times: defaultRetryTimes,
 	}
 }

+ 16 - 16
core/fx/retry_test.go

@@ -46,7 +46,7 @@ func TestRetry(t *testing.T) {
 func TestRetryWithTimeout(t *testing.T) {
 	assert.Nil(t, DoWithRetry(func() error {
 		return nil
-	}, WithTimeout(time.Second*10)))
+	}, WithTimeout(time.Millisecond*500)))
 
 	times1 := 0
 	assert.Nil(t, DoWithRetry(func() error {
@@ -54,9 +54,9 @@ func TestRetryWithTimeout(t *testing.T) {
 		if times1 == 1 {
 			return errors.New("any ")
 		}
-		time.Sleep(time.Second * 3)
+		time.Sleep(time.Millisecond * 150)
 		return nil
-	}, WithTimeout(time.Second*5)))
+	}, WithTimeout(time.Millisecond*250)))
 
 	total := defaultRetryTimes
 	times2 := 0
@@ -65,13 +65,13 @@ func TestRetryWithTimeout(t *testing.T) {
 		if times2 == total {
 			return nil
 		}
-		time.Sleep(time.Second)
+		time.Sleep(time.Millisecond * 50)
 		return errors.New("any")
-	}, WithTimeout(time.Second*(time.Duration(total)+2))))
+	}, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
 
 	assert.NotNil(t, DoWithRetry(func() error {
 		return errors.New("any")
-	}, WithTimeout(time.Second*5)))
+	}, WithTimeout(time.Millisecond*250)))
 }
 
 func TestRetryWithInterval(t *testing.T) {
@@ -81,9 +81,9 @@ func TestRetryWithInterval(t *testing.T) {
 		if times1 == 1 {
 			return errors.New("any")
 		}
-		time.Sleep(time.Second * 3)
+		time.Sleep(time.Millisecond * 150)
 		return nil
-	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+	}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
 
 	times2 := 0
 	assert.NotNil(t, DoWithRetry(func() error {
@@ -91,26 +91,26 @@ func TestRetryWithInterval(t *testing.T) {
 		if times2 == 2 {
 			return nil
 		}
-		time.Sleep(time.Second * 3)
+		time.Sleep(time.Millisecond * 150)
 		return errors.New("any ")
-	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+	}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
 
 }
 
 func TestRetryCtx(t *testing.T) {
-	assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error {
+	assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
 		if retryCount == 0 {
 			return errors.New("any")
 		}
-		time.Sleep(time.Second * 3)
+		time.Sleep(time.Millisecond * 150)
 		return nil
-	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+	}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
 
-	assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error {
+	assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
 		if retryCount == 1 {
 			return nil
 		}
-		time.Sleep(time.Second * 3)
+		time.Sleep(time.Millisecond * 150)
 		return errors.New("any ")
-	}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
+	}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
 }