timingwheel.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. package collection
  2. import (
  3. "container/list"
  4. "errors"
  5. "fmt"
  6. "time"
  7. "github.com/zeromicro/go-zero/core/lang"
  8. "github.com/zeromicro/go-zero/core/threading"
  9. "github.com/zeromicro/go-zero/core/timex"
  10. )
  // drainWorkers is the concurrency passed to threading.NewTaskRunner when
// Drain runs its callback over pending tasks (presumably the maximum number
// of callbacks in flight at once — see threading.TaskRunner).
const drainWorkers = 8

// ErrClosed is returned by the exported TimingWheel methods once the wheel
// has been stopped.
var ErrClosed = errors.New("TimingWheel is closed already")

// ErrArgument is returned when a timer request has a nil key or a
// non-positive delay.
var ErrArgument = errors.New("incorrect task argument")
  16. type (
  17. // Execute defines the method to execute the task.
  18. Execute func(key, value interface{})
  19. // A TimingWheel is a timing wheel object to schedule tasks.
  20. TimingWheel struct {
  21. interval time.Duration
  22. ticker timex.Ticker
  23. slots []*list.List
  24. timers *SafeMap
  25. tickedPos int
  26. numSlots int
  27. execute Execute
  28. setChannel chan timingEntry
  29. moveChannel chan baseEntry
  30. removeChannel chan interface{}
  31. drainChannel chan func(key, value interface{})
  32. stopChannel chan lang.PlaceholderType
  33. }
  34. timingEntry struct {
  35. baseEntry
  36. value interface{}
  37. circle int
  38. diff int
  39. removed bool
  40. }
  41. baseEntry struct {
  42. delay time.Duration
  43. key interface{}
  44. }
  45. positionEntry struct {
  46. pos int
  47. item *timingEntry
  48. }
  49. timingTask struct {
  50. key interface{}
  51. value interface{}
  52. }
  53. )
  54. // NewTimingWheel returns a TimingWheel.
  55. func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
  56. if interval <= 0 || numSlots <= 0 || execute == nil {
  57. return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p",
  58. interval, numSlots, execute)
  59. }
  60. return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
  61. }
  62. func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute,
  63. ticker timex.Ticker) (*TimingWheel, error) {
  64. tw := &TimingWheel{
  65. interval: interval,
  66. ticker: ticker,
  67. slots: make([]*list.List, numSlots),
  68. timers: NewSafeMap(),
  69. tickedPos: numSlots - 1, // at previous virtual circle
  70. execute: execute,
  71. numSlots: numSlots,
  72. setChannel: make(chan timingEntry),
  73. moveChannel: make(chan baseEntry),
  74. removeChannel: make(chan interface{}),
  75. drainChannel: make(chan func(key, value interface{})),
  76. stopChannel: make(chan lang.PlaceholderType),
  77. }
  78. tw.initSlots()
  79. go tw.run()
  80. return tw, nil
  81. }
  82. // Drain drains all items and executes them.
  83. func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
  84. select {
  85. case tw.drainChannel <- fn:
  86. return nil
  87. case <-tw.stopChannel:
  88. return ErrClosed
  89. }
  90. }
  91. // MoveTimer moves the task with the given key to the given delay.
  92. func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
  93. if delay <= 0 || key == nil {
  94. return ErrArgument
  95. }
  96. select {
  97. case tw.moveChannel <- baseEntry{
  98. delay: delay,
  99. key: key,
  100. }:
  101. return nil
  102. case <-tw.stopChannel:
  103. return ErrClosed
  104. }
  105. }
  106. // RemoveTimer removes the task with the given key.
  107. func (tw *TimingWheel) RemoveTimer(key interface{}) error {
  108. if key == nil {
  109. return ErrArgument
  110. }
  111. select {
  112. case tw.removeChannel <- key:
  113. return nil
  114. case <-tw.stopChannel:
  115. return ErrClosed
  116. }
  117. }
  118. // SetTimer sets the task value with the given key to the delay.
  119. func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
  120. if delay <= 0 || key == nil {
  121. return ErrArgument
  122. }
  123. select {
  124. case tw.setChannel <- timingEntry{
  125. baseEntry: baseEntry{
  126. delay: delay,
  127. key: key,
  128. },
  129. value: value,
  130. }:
  131. return nil
  132. case <-tw.stopChannel:
  133. return ErrClosed
  134. }
  135. }
  136. // Stop stops tw. No more actions after stopping a TimingWheel.
  137. func (tw *TimingWheel) Stop() {
  138. close(tw.stopChannel)
  139. }
  140. func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
  141. runner := threading.NewTaskRunner(drainWorkers)
  142. for _, slot := range tw.slots {
  143. for e := slot.Front(); e != nil; {
  144. task := e.Value.(*timingEntry)
  145. next := e.Next()
  146. slot.Remove(e)
  147. e = next
  148. if !task.removed {
  149. runner.Schedule(func() {
  150. fn(task.key, task.value)
  151. })
  152. }
  153. }
  154. }
  155. }
  156. func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
  157. steps := int(d / tw.interval)
  158. pos = (tw.tickedPos + steps) % tw.numSlots
  159. circle = (steps - 1) / tw.numSlots
  160. return
  161. }
  162. func (tw *TimingWheel) initSlots() {
  163. for i := 0; i < tw.numSlots; i++ {
  164. tw.slots[i] = list.New()
  165. }
  166. }
  167. func (tw *TimingWheel) moveTask(task baseEntry) {
  168. val, ok := tw.timers.Get(task.key)
  169. if !ok {
  170. return
  171. }
  172. timer := val.(*positionEntry)
  173. if task.delay < tw.interval {
  174. threading.GoSafe(func() {
  175. tw.execute(timer.item.key, timer.item.value)
  176. })
  177. return
  178. }
  179. pos, circle := tw.getPositionAndCircle(task.delay)
  180. if pos >= timer.pos {
  181. timer.item.circle = circle
  182. timer.item.diff = pos - timer.pos
  183. } else if circle > 0 {
  184. circle--
  185. timer.item.circle = circle
  186. timer.item.diff = tw.numSlots + pos - timer.pos
  187. } else {
  188. timer.item.removed = true
  189. newItem := &timingEntry{
  190. baseEntry: task,
  191. value: timer.item.value,
  192. }
  193. tw.slots[pos].PushBack(newItem)
  194. tw.setTimerPosition(pos, newItem)
  195. }
  196. }
  197. func (tw *TimingWheel) onTick() {
  198. tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
  199. l := tw.slots[tw.tickedPos]
  200. tw.scanAndRunTasks(l)
  201. }
  202. func (tw *TimingWheel) removeTask(key interface{}) {
  203. val, ok := tw.timers.Get(key)
  204. if !ok {
  205. return
  206. }
  207. timer := val.(*positionEntry)
  208. timer.item.removed = true
  209. tw.timers.Del(key)
  210. }
  211. func (tw *TimingWheel) run() {
  212. for {
  213. select {
  214. case <-tw.ticker.Chan():
  215. tw.onTick()
  216. case task := <-tw.setChannel:
  217. tw.setTask(&task)
  218. case key := <-tw.removeChannel:
  219. tw.removeTask(key)
  220. case task := <-tw.moveChannel:
  221. tw.moveTask(task)
  222. case fn := <-tw.drainChannel:
  223. tw.drainAll(fn)
  224. case <-tw.stopChannel:
  225. tw.ticker.Stop()
  226. return
  227. }
  228. }
  229. }
  230. func (tw *TimingWheel) runTasks(tasks []timingTask) {
  231. if len(tasks) == 0 {
  232. return
  233. }
  234. go func() {
  235. for i := range tasks {
  236. threading.RunSafe(func() {
  237. tw.execute(tasks[i].key, tasks[i].value)
  238. })
  239. }
  240. }()
  241. }
  242. func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
  243. var tasks []timingTask
  244. for e := l.Front(); e != nil; {
  245. task := e.Value.(*timingEntry)
  246. if task.removed {
  247. next := e.Next()
  248. l.Remove(e)
  249. e = next
  250. continue
  251. } else if task.circle > 0 {
  252. task.circle--
  253. e = e.Next()
  254. continue
  255. } else if task.diff > 0 {
  256. next := e.Next()
  257. l.Remove(e)
  258. // (tw.tickedPos+task.diff)%tw.numSlots
  259. // cannot be the same value of tw.tickedPos
  260. pos := (tw.tickedPos + task.diff) % tw.numSlots
  261. tw.slots[pos].PushBack(task)
  262. tw.setTimerPosition(pos, task)
  263. task.diff = 0
  264. e = next
  265. continue
  266. }
  267. tasks = append(tasks, timingTask{
  268. key: task.key,
  269. value: task.value,
  270. })
  271. next := e.Next()
  272. l.Remove(e)
  273. tw.timers.Del(task.key)
  274. e = next
  275. }
  276. tw.runTasks(tasks)
  277. }
  278. func (tw *TimingWheel) setTask(task *timingEntry) {
  279. if task.delay < tw.interval {
  280. task.delay = tw.interval
  281. }
  282. if val, ok := tw.timers.Get(task.key); ok {
  283. entry := val.(*positionEntry)
  284. entry.item.value = task.value
  285. tw.moveTask(task.baseEntry)
  286. } else {
  287. pos, circle := tw.getPositionAndCircle(task.delay)
  288. task.circle = circle
  289. tw.slots[pos].PushBack(task)
  290. tw.setTimerPosition(pos, task)
  291. }
  292. }
  293. func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
  294. if val, ok := tw.timers.Get(task.key); ok {
  295. timer := val.(*positionEntry)
  296. timer.item = task
  297. timer.pos = pos
  298. } else {
  299. tw.timers.Set(task.key, &positionEntry{
  300. pos: pos,
  301. item: task,
  302. })
  303. }
  304. }