timingwheel.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. package collection
  2. import (
  3. "container/list"
  4. "errors"
  5. "fmt"
  6. "time"
  7. "github.com/wuntsong-org/go-zero-plus/core/lang"
  8. "github.com/wuntsong-org/go-zero-plus/core/threading"
  9. "github.com/wuntsong-org/go-zero-plus/core/timex"
  10. )
  // drainWorkers is passed to threading.NewTaskRunner in drainAll; presumably
  // it caps how many drain callbacks run concurrently.
  const drainWorkers = 8

  // ErrClosed is returned by SetTimer, MoveTimer, RemoveTimer and Drain once
  // the wheel's stop channel has been closed by Stop.
  var ErrClosed = errors.New("TimingWheel is closed already")

  // ErrArgument is returned for a nil key or a non-positive delay.
  var ErrArgument = errors.New("incorrect task argument")
  16. type (
  17. // Execute defines the method to execute the task.
  18. Execute func(key, value any)
  19. // A TimingWheel is a timing wheel object to schedule tasks.
  20. TimingWheel struct {
  21. interval time.Duration
  22. ticker timex.Ticker
  23. slots []*list.List
  24. timers *SafeMap
  25. tickedPos int
  26. numSlots int
  27. execute Execute
  28. setChannel chan timingEntry
  29. moveChannel chan baseEntry
  30. removeChannel chan any
  31. drainChannel chan func(key, value any)
  32. stopChannel chan lang.PlaceholderType
  33. }
  34. timingEntry struct {
  35. baseEntry
  36. value any
  37. circle int
  38. diff int
  39. removed bool
  40. }
  41. baseEntry struct {
  42. delay time.Duration
  43. key any
  44. }
  45. positionEntry struct {
  46. pos int
  47. item *timingEntry
  48. }
  49. timingTask struct {
  50. key any
  51. value any
  52. }
  53. )
  54. // NewTimingWheel returns a TimingWheel.
  55. func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
  56. if interval <= 0 || numSlots <= 0 || execute == nil {
  57. return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p",
  58. interval, numSlots, execute)
  59. }
  60. return NewTimingWheelWithTicker(interval, numSlots, execute, timex.NewTicker(interval))
  61. }
  62. // NewTimingWheelWithTicker returns a TimingWheel with the given ticker.
  63. func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Execute,
  64. ticker timex.Ticker) (*TimingWheel, error) {
  65. tw := &TimingWheel{
  66. interval: interval,
  67. ticker: ticker,
  68. slots: make([]*list.List, numSlots),
  69. timers: NewSafeMap(),
  70. tickedPos: numSlots - 1, // at previous virtual circle
  71. execute: execute,
  72. numSlots: numSlots,
  73. setChannel: make(chan timingEntry),
  74. moveChannel: make(chan baseEntry),
  75. removeChannel: make(chan any),
  76. drainChannel: make(chan func(key, value any)),
  77. stopChannel: make(chan lang.PlaceholderType),
  78. }
  79. tw.initSlots()
  80. go tw.run()
  81. return tw, nil
  82. }
  83. // Drain drains all items and executes them.
  84. func (tw *TimingWheel) Drain(fn func(key, value any)) error {
  85. select {
  86. case tw.drainChannel <- fn:
  87. return nil
  88. case <-tw.stopChannel:
  89. return ErrClosed
  90. }
  91. }
  92. // MoveTimer moves the task with the given key to the given delay.
  93. func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
  94. if delay <= 0 || key == nil {
  95. return ErrArgument
  96. }
  97. select {
  98. case tw.moveChannel <- baseEntry{
  99. delay: delay,
  100. key: key,
  101. }:
  102. return nil
  103. case <-tw.stopChannel:
  104. return ErrClosed
  105. }
  106. }
  107. // RemoveTimer removes the task with the given key.
  108. func (tw *TimingWheel) RemoveTimer(key any) error {
  109. if key == nil {
  110. return ErrArgument
  111. }
  112. select {
  113. case tw.removeChannel <- key:
  114. return nil
  115. case <-tw.stopChannel:
  116. return ErrClosed
  117. }
  118. }
  119. // SetTimer sets the task value with the given key to the delay.
  120. func (tw *TimingWheel) SetTimer(key, value any, delay time.Duration) error {
  121. if delay <= 0 || key == nil {
  122. return ErrArgument
  123. }
  124. select {
  125. case tw.setChannel <- timingEntry{
  126. baseEntry: baseEntry{
  127. delay: delay,
  128. key: key,
  129. },
  130. value: value,
  131. }:
  132. return nil
  133. case <-tw.stopChannel:
  134. return ErrClosed
  135. }
  136. }
  137. // Stop stops tw. No more actions after stopping a TimingWheel.
  138. func (tw *TimingWheel) Stop() {
  139. close(tw.stopChannel)
  140. }
  141. func (tw *TimingWheel) drainAll(fn func(key, value any)) {
  142. runner := threading.NewTaskRunner(drainWorkers)
  143. for _, slot := range tw.slots {
  144. for e := slot.Front(); e != nil; {
  145. task := e.Value.(*timingEntry)
  146. next := e.Next()
  147. slot.Remove(e)
  148. e = next
  149. if !task.removed {
  150. runner.Schedule(func() {
  151. fn(task.key, task.value)
  152. })
  153. }
  154. }
  155. }
  156. }
  157. func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
  158. steps := int(d / tw.interval)
  159. pos = (tw.tickedPos + steps) % tw.numSlots
  160. circle = (steps - 1) / tw.numSlots
  161. return
  162. }
  163. func (tw *TimingWheel) initSlots() {
  164. for i := 0; i < tw.numSlots; i++ {
  165. tw.slots[i] = list.New()
  166. }
  167. }
  168. func (tw *TimingWheel) moveTask(task baseEntry) {
  169. val, ok := tw.timers.Get(task.key)
  170. if !ok {
  171. return
  172. }
  173. timer := val.(*positionEntry)
  174. if task.delay < tw.interval {
  175. threading.GoSafe(func() {
  176. tw.execute(timer.item.key, timer.item.value)
  177. })
  178. return
  179. }
  180. pos, circle := tw.getPositionAndCircle(task.delay)
  181. if pos >= timer.pos {
  182. timer.item.circle = circle
  183. timer.item.diff = pos - timer.pos
  184. } else if circle > 0 {
  185. circle--
  186. timer.item.circle = circle
  187. timer.item.diff = tw.numSlots + pos - timer.pos
  188. } else {
  189. timer.item.removed = true
  190. newItem := &timingEntry{
  191. baseEntry: task,
  192. value: timer.item.value,
  193. }
  194. tw.slots[pos].PushBack(newItem)
  195. tw.setTimerPosition(pos, newItem)
  196. }
  197. }
  198. func (tw *TimingWheel) onTick() {
  199. tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
  200. l := tw.slots[tw.tickedPos]
  201. tw.scanAndRunTasks(l)
  202. }
  203. func (tw *TimingWheel) removeTask(key any) {
  204. val, ok := tw.timers.Get(key)
  205. if !ok {
  206. return
  207. }
  208. timer := val.(*positionEntry)
  209. timer.item.removed = true
  210. tw.timers.Del(key)
  211. }
  212. func (tw *TimingWheel) run() {
  213. for {
  214. select {
  215. case <-tw.ticker.Chan():
  216. tw.onTick()
  217. case task := <-tw.setChannel:
  218. tw.setTask(&task)
  219. case key := <-tw.removeChannel:
  220. tw.removeTask(key)
  221. case task := <-tw.moveChannel:
  222. tw.moveTask(task)
  223. case fn := <-tw.drainChannel:
  224. tw.drainAll(fn)
  225. case <-tw.stopChannel:
  226. tw.ticker.Stop()
  227. return
  228. }
  229. }
  230. }
  231. func (tw *TimingWheel) runTasks(tasks []timingTask) {
  232. if len(tasks) == 0 {
  233. return
  234. }
  235. go func() {
  236. for i := range tasks {
  237. threading.RunSafe(func() {
  238. tw.execute(tasks[i].key, tasks[i].value)
  239. })
  240. }
  241. }()
  242. }
  243. func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
  244. var tasks []timingTask
  245. for e := l.Front(); e != nil; {
  246. task := e.Value.(*timingEntry)
  247. if task.removed {
  248. next := e.Next()
  249. l.Remove(e)
  250. e = next
  251. continue
  252. } else if task.circle > 0 {
  253. task.circle--
  254. e = e.Next()
  255. continue
  256. } else if task.diff > 0 {
  257. next := e.Next()
  258. l.Remove(e)
  259. // (tw.tickedPos+task.diff)%tw.numSlots
  260. // cannot be the same value of tw.tickedPos
  261. pos := (tw.tickedPos + task.diff) % tw.numSlots
  262. tw.slots[pos].PushBack(task)
  263. tw.setTimerPosition(pos, task)
  264. task.diff = 0
  265. e = next
  266. continue
  267. }
  268. tasks = append(tasks, timingTask{
  269. key: task.key,
  270. value: task.value,
  271. })
  272. next := e.Next()
  273. l.Remove(e)
  274. tw.timers.Del(task.key)
  275. e = next
  276. }
  277. tw.runTasks(tasks)
  278. }
  279. func (tw *TimingWheel) setTask(task *timingEntry) {
  280. if task.delay < tw.interval {
  281. task.delay = tw.interval
  282. }
  283. if val, ok := tw.timers.Get(task.key); ok {
  284. entry := val.(*positionEntry)
  285. entry.item.value = task.value
  286. tw.moveTask(task.baseEntry)
  287. } else {
  288. pos, circle := tw.getPositionAndCircle(task.delay)
  289. task.circle = circle
  290. tw.slots[pos].PushBack(task)
  291. tw.setTimerPosition(pos, task)
  292. }
  293. }
  294. func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
  295. if val, ok := tw.timers.Get(task.key); ok {
  296. timer := val.(*positionEntry)
  297. timer.item = task
  298. timer.pos = pos
  299. } else {
  300. tw.timers.Set(task.key, &positionEntry{
  301. pos: pos,
  302. item: task,
  303. })
  304. }
  305. }