kevin 4 gadi atpakaļ
vecāks
revīzija
295c8d2934
2 mainītis faili ar 27 papildinājumiem un 1 dzēšanām
  1. 2 1
      core/collection/timingwheel.go
  2. 25 0
      core/collection/timingwheel_test.go

+ 2 - 1
core/collection/timingwheel.go

@@ -204,6 +204,7 @@ func (tw *TimingWheel) removeTask(key interface{}) {
 
 	timer := val.(*positionEntry)
 	timer.item.removed = true
+	tw.timers.Del(key)
 }
 
 func (tw *TimingWheel) run() {
@@ -248,7 +249,6 @@ func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
 		if task.removed {
 			next := e.Next()
 			l.Remove(e)
-			tw.timers.Del(task.key)
 			e = next
 			continue
 		} else if task.circle > 0 {
@@ -301,6 +301,7 @@ func (tw *TimingWheel) setTask(task *timingEntry) {
 func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
 	if val, ok := tw.timers.Get(task.key); ok {
 		timer := val.(*positionEntry)
+		timer.item = task
 		timer.pos = pos
 	} else {
 		tw.timers.Set(task.key, &positionEntry{

+ 25 - 0
core/collection/timingwheel_test.go

@@ -594,6 +594,31 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
 	}
 }
 
+func TestMoveAndRemoveTask(t *testing.T) {
+	ticker := timex.NewFakeTicker()
+	tick := func(v int) {
+		for i := 0; i < v; i++ {
+			ticker.Tick()
+		}
+	}
+	var keys []int
+	tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
+		assert.Equal(t, "any", k)
+		assert.Equal(t, 3, v.(int))
+		keys = append(keys, v.(int))
+		ticker.Done()
+	}, ticker)
+	defer tw.Stop()
+	tw.SetTimer("any", 3, testStep*8)
+	tick(6)
+	tw.MoveTimer("any", testStep*7)
+	tick(3)
+	tw.RemoveTimer("any")
+	tick(30)
+	time.Sleep(time.Millisecond)
+	assert.Equal(t, 0, len(keys))
+}
+
 func BenchmarkTimingWheel(b *testing.B) {
 	b.ReportAllocs()