浏览代码

chore: avoid deadlock after stopping TimingWheel (#1768)

Kevin Wan 3 年之前
父节点
当前提交
8bc34defc4
共有 2 个文件被更改,包括 49 次插入13 次删除
  1. 37 11
      core/collection/timingwheel.go
  2. 12 2
      core/collection/timingwheel_test.go

+ 37 - 11
core/collection/timingwheel.go

@@ -2,6 +2,7 @@ package collection
 
 import (
 	"container/list"
+	"errors"
 	"fmt"
 	"time"
 
@@ -12,6 +13,11 @@ import (
 
 const drainWorkers = 8
 
+var (
+	ErrClosed   = errors.New("TimingWheel is closed already")
+	ErrArgument = errors.New("incorrect task argument")
+)
+
 type (
 	// Execute defines the method to execute the task.
 	Execute func(key, value interface{})
@@ -89,43 +95,63 @@ func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execu
 }
 
 // Drain drains all items and executes them.
-func (tw *TimingWheel) Drain(fn func(key, value interface{})) {
-	tw.drainChannel <- fn
+func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
+	select {
+	case tw.drainChannel <- fn:
+		return nil
+	case <-tw.stopChannel:
+		return ErrClosed
+	}
 }
 
 // MoveTimer moves the task with the given key to the given delay.
-func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
+func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
 	if delay <= 0 || key == nil {
-		return
+		return ErrArgument
 	}
 
-	tw.moveChannel <- baseEntry{
+	select {
+	case tw.moveChannel <- baseEntry{
 		delay: delay,
 		key:   key,
+	}:
+		return nil
+	case <-tw.stopChannel:
+		return ErrClosed
 	}
 }
 
 // RemoveTimer removes the task with the given key.
-func (tw *TimingWheel) RemoveTimer(key interface{}) {
+func (tw *TimingWheel) RemoveTimer(key interface{}) error {
 	if key == nil {
-		return
+		return ErrArgument
 	}
 
-	tw.removeChannel <- key
+	select {
+	case tw.removeChannel <- key:
+		return nil
+	case <-tw.stopChannel:
+		return ErrClosed
+	}
 }
 
 // SetTimer sets the task value with the given key to the delay.
-func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
+func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
 	if delay <= 0 || key == nil {
-		return
+		return ErrArgument
 	}
 
-	tw.setChannel <- timingEntry{
+	select {
+	case tw.setChannel <- timingEntry{
 		baseEntry: baseEntry{
 			delay: delay,
 			key:   key,
 		},
 		value: value,
+	}:
+		return nil
+	case <-tw.stopChannel:
+		return ErrClosed
 	}
 }
 

+ 12 - 2
core/collection/timingwheel_test.go

@@ -28,7 +28,6 @@ func TestTimingWheel_Drain(t *testing.T) {
 	ticker := timex.NewFakeTicker()
 	tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
 	}, ticker)
-	defer tw.Stop()
 	tw.SetTimer("first", 3, testStep*4)
 	tw.SetTimer("second", 5, testStep*7)
 	tw.SetTimer("third", 7, testStep*7)
@@ -56,6 +55,8 @@ func TestTimingWheel_Drain(t *testing.T) {
 	})
 	time.Sleep(time.Millisecond * 100)
 	assert.Equal(t, 0, count)
+	tw.Stop()
+	assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {}))
 }
 
 func TestTimingWheel_SetTimerSoon(t *testing.T) {
@@ -102,6 +103,13 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
 	})
 }
 
+func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
+	ticker := timex.NewFakeTicker()
+	tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
+	tw.Stop()
+	assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
+}
+
 func TestTimingWheel_MoveTimer(t *testing.T) {
 	run := syncx.NewAtomicBool()
 	ticker := timex.NewFakeTicker()
@@ -111,7 +119,6 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
 		assert.Equal(t, 3, v.(int))
 		ticker.Done()
 	}, ticker)
-	defer tw.Stop()
 	tw.SetTimer("any", 3, testStep*4)
 	tw.MoveTimer("any", testStep*7)
 	tw.MoveTimer("any", -testStep)
@@ -125,6 +132,8 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
 	}
 	assert.Nil(t, ticker.Wait(waitTime))
 	assert.True(t, run.True())
+	tw.Stop()
+	assert.Equal(t, ErrClosed, tw.MoveTimer("any", time.Millisecond))
 }
 
 func TestTimingWheel_MoveTimerSoon(t *testing.T) {
@@ -175,6 +184,7 @@ func TestTimingWheel_RemoveTimer(t *testing.T) {
 		ticker.Tick()
 	}
 	tw.Stop()
+	assert.Equal(t, ErrClosed, tw.RemoveTimer("any"))
 }
 
 func TestTimingWheel_SetTimer(t *testing.T) {