roundrobinbalancer.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package internal
  2. import (
  3. "math/rand"
  4. "time"
  5. "zero/core/logx"
  6. )
  7. type roundRobinBalancer struct {
  8. *baseBalancer
  9. conns []serverConn
  10. index int
  11. }
  12. func NewRoundRobinBalancer(dialFn DialFn, closeFn CloseFn, exclusive bool) *roundRobinBalancer {
  13. balancer := new(roundRobinBalancer)
  14. balancer.baseBalancer = newBaseBalancer(dialFn, closeFn, exclusive)
  15. return balancer
  16. }
  17. func (b *roundRobinBalancer) AddConn(kv KV) error {
  18. var conn interface{}
  19. prev, found := b.addKv(kv.Key, kv.Val)
  20. if found {
  21. conn = b.handlePrevious(prev, kv.Val)
  22. }
  23. if conn == nil {
  24. var err error
  25. conn, err = b.dialFn(kv.Val)
  26. if err != nil {
  27. b.removeKv(kv.Key)
  28. return err
  29. }
  30. }
  31. b.lock.Lock()
  32. defer b.lock.Unlock()
  33. b.conns = append(b.conns, serverConn{
  34. key: kv.Key,
  35. conn: conn,
  36. })
  37. b.notify(kv.Key)
  38. return nil
  39. }
  40. func (b *roundRobinBalancer) handlePrevious(prev []string, server string) interface{} {
  41. if len(prev) == 0 {
  42. return nil
  43. }
  44. b.lock.Lock()
  45. defer b.lock.Unlock()
  46. if b.exclusive {
  47. for _, item := range prev {
  48. conns := b.conns[:0]
  49. for _, each := range b.conns {
  50. if each.key == item {
  51. if err := b.closeFn(server, each.conn); err != nil {
  52. logx.Error(err)
  53. }
  54. } else {
  55. conns = append(conns, each)
  56. }
  57. }
  58. b.conns = conns
  59. }
  60. } else {
  61. for _, each := range b.conns {
  62. if each.key == prev[0] {
  63. return each.conn
  64. }
  65. }
  66. }
  67. return nil
  68. }
  69. func (b *roundRobinBalancer) initialize() {
  70. rand.Seed(time.Now().UnixNano())
  71. if len(b.conns) > 0 {
  72. b.index = rand.Intn(len(b.conns))
  73. }
  74. }
  75. func (b *roundRobinBalancer) IsEmpty() bool {
  76. b.lock.Lock()
  77. empty := len(b.conns) == 0
  78. b.lock.Unlock()
  79. return empty
  80. }
  81. func (b *roundRobinBalancer) Next(...string) (interface{}, bool) {
  82. b.lock.Lock()
  83. defer b.lock.Unlock()
  84. if len(b.conns) == 0 {
  85. return nil, false
  86. }
  87. b.index = (b.index + 1) % len(b.conns)
  88. return b.conns[b.index].conn, true
  89. }
  90. func (b *roundRobinBalancer) notify(key string) {
  91. if b.listener == nil {
  92. return
  93. }
  94. // b.servers has the format of map[conn][]key
  95. var keys []string
  96. var values []string
  97. for k, v := range b.servers {
  98. values = append(values, k)
  99. keys = append(keys, v...)
  100. }
  101. b.listener.OnUpdate(keys, values, key)
  102. }
  103. func (b *roundRobinBalancer) RemoveKey(key string) {
  104. server, keep := b.removeKv(key)
  105. b.lock.Lock()
  106. defer b.lock.Unlock()
  107. conns := b.conns[:0]
  108. for _, conn := range b.conns {
  109. if conn.key == key {
  110. // there are other keys assocated with the conn, don't close the conn.
  111. if keep {
  112. continue
  113. }
  114. if err := b.closeFn(server, conn.conn); err != nil {
  115. logx.Error(err)
  116. }
  117. } else {
  118. conns = append(conns, conn)
  119. }
  120. }
  121. b.conns = conns
  122. // notify without new key
  123. b.notify("")
  124. }