1
0

mapreduce.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. package mr
  2. import (
  3. "context"
  4. "errors"
  5. "sync"
  6. "sync/atomic"
  7. "github.com/wuntsong-org/go-zero-plus/core/errorx"
  8. )
  // Worker-count settings used when building mapreduce options.
  const (
  	// minWorkers is the smallest worker count WithWorkers will accept;
  	// anything lower is raised to this value.
  	minWorkers = 1
  	// defaultWorkers is the worker count installed by newOptions.
  	defaultWorkers = 16
  )
  // Sentinel errors returned by the mapreduce entry points.
  var (
  	// ErrReduceNoOutput is returned when the reducer finished without
  	// writing any value.
  	ErrReduceNoOutput = errors.New("reduce not writing value")
  	// ErrCancelWithNil is returned when the cancel func was invoked
  	// with a nil error.
  	ErrCancelWithNil = errors.New("mapreduce cancelled with nil")
  )
  19. type (
  20. // ForEachFunc is used to do element processing, but no output.
  21. ForEachFunc[T any] func(item T)
  22. // GenerateFunc is used to let callers send elements into source.
  23. GenerateFunc[T any] func(source chan<- T)
  24. // MapFunc is used to do element processing and write the output to writer.
  25. MapFunc[T, U any] func(item T, writer Writer[U])
  26. // MapperFunc is used to do element processing and write the output to writer,
  27. // use cancel func to cancel the processing.
  28. MapperFunc[T, U any] func(item T, writer Writer[U], cancel func(error))
  29. // ReducerFunc is used to reduce all the mapping output and write to writer,
  30. // use cancel func to cancel the processing.
  31. ReducerFunc[U, V any] func(pipe <-chan U, writer Writer[V], cancel func(error))
  32. // VoidReducerFunc is used to reduce all the mapping output, but no output.
  33. // Use cancel func to cancel the processing.
  34. VoidReducerFunc[U any] func(pipe <-chan U, cancel func(error))
  35. // Option defines the method to customize the mapreduce.
  36. Option func(opts *mapReduceOptions)
  37. mapperContext[T, U any] struct {
  38. ctx context.Context
  39. mapper MapFunc[T, U]
  40. source <-chan T
  41. panicChan *onceChan
  42. collector chan<- U
  43. doneChan <-chan struct{}
  44. workers int
  45. }
  46. mapReduceOptions struct {
  47. ctx context.Context
  48. workers int
  49. }
  50. // Writer interface wraps Write method.
  51. Writer[T any] interface {
  52. Write(v T)
  53. }
  54. )
  55. // Finish runs fns parallelly, cancelled on any error.
  56. func Finish(fns ...func() error) error {
  57. if len(fns) == 0 {
  58. return nil
  59. }
  60. return MapReduceVoid(func(source chan<- func() error) {
  61. for _, fn := range fns {
  62. source <- fn
  63. }
  64. }, func(fn func() error, writer Writer[any], cancel func(error)) {
  65. if err := fn(); err != nil {
  66. cancel(err)
  67. }
  68. }, func(pipe <-chan any, cancel func(error)) {
  69. }, WithWorkers(len(fns)))
  70. }
  71. // FinishVoid runs fns parallelly.
  72. func FinishVoid(fns ...func()) {
  73. if len(fns) == 0 {
  74. return
  75. }
  76. ForEach(func(source chan<- func()) {
  77. for _, fn := range fns {
  78. source <- fn
  79. }
  80. }, func(fn func()) {
  81. fn()
  82. }, WithWorkers(len(fns)))
  83. }
  84. // ForEach maps all elements from given generate but no output.
  85. func ForEach[T any](generate GenerateFunc[T], mapper ForEachFunc[T], opts ...Option) {
  86. options := buildOptions(opts...)
  87. panicChan := &onceChan{channel: make(chan any)}
  88. source := buildSource(generate, panicChan)
  89. collector := make(chan any)
  90. done := make(chan struct{})
  91. go executeMappers(mapperContext[T, any]{
  92. ctx: options.ctx,
  93. mapper: func(item T, _ Writer[any]) {
  94. mapper(item)
  95. },
  96. source: source,
  97. panicChan: panicChan,
  98. collector: collector,
  99. doneChan: done,
  100. workers: options.workers,
  101. })
  102. for {
  103. select {
  104. case v := <-panicChan.channel:
  105. panic(v)
  106. case _, ok := <-collector:
  107. if !ok {
  108. return
  109. }
  110. }
  111. }
  112. }
  113. // MapReduce maps all elements generated from given generate func,
  114. // and reduces the output elements with given reducer.
  115. func MapReduce[T, U, V any](generate GenerateFunc[T], mapper MapperFunc[T, U], reducer ReducerFunc[U, V],
  116. opts ...Option) (V, error) {
  117. panicChan := &onceChan{channel: make(chan any)}
  118. source := buildSource(generate, panicChan)
  119. return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
  120. }
  121. // MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
  122. func MapReduceChan[T, U, V any](source <-chan T, mapper MapperFunc[T, U], reducer ReducerFunc[U, V],
  123. opts ...Option) (V, error) {
  124. panicChan := &onceChan{channel: make(chan any)}
  125. return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
  126. }
  127. // mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
  128. func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
  129. reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
  130. options := buildOptions(opts...)
  131. // output is used to write the final result
  132. output := make(chan V)
  133. defer func() {
  134. // reducer can only write once, if more, panic
  135. for range output {
  136. panic("more than one element written in reducer")
  137. }
  138. }()
  139. // collector is used to collect data from mapper, and consume in reducer
  140. collector := make(chan U, options.workers)
  141. // if done is closed, all mappers and reducer should stop processing
  142. done := make(chan struct{})
  143. writer := newGuardedWriter(options.ctx, output, done)
  144. var closeOnce sync.Once
  145. // use atomic type to avoid data race
  146. var retErr errorx.AtomicError
  147. finish := func() {
  148. closeOnce.Do(func() {
  149. close(done)
  150. close(output)
  151. })
  152. }
  153. cancel := once(func(err error) {
  154. if err != nil {
  155. retErr.Set(err)
  156. } else {
  157. retErr.Set(ErrCancelWithNil)
  158. }
  159. drain(source)
  160. finish()
  161. })
  162. go func() {
  163. defer func() {
  164. drain(collector)
  165. if r := recover(); r != nil {
  166. panicChan.write(r)
  167. }
  168. finish()
  169. }()
  170. reducer(collector, writer, cancel)
  171. }()
  172. go executeMappers(mapperContext[T, U]{
  173. ctx: options.ctx,
  174. mapper: func(item T, w Writer[U]) {
  175. mapper(item, w, cancel)
  176. },
  177. source: source,
  178. panicChan: panicChan,
  179. collector: collector,
  180. doneChan: done,
  181. workers: options.workers,
  182. })
  183. select {
  184. case <-options.ctx.Done():
  185. cancel(context.DeadlineExceeded)
  186. err = context.DeadlineExceeded
  187. case v := <-panicChan.channel:
  188. // drain output here, otherwise for loop panic in defer
  189. drain(output)
  190. panic(v)
  191. case v, ok := <-output:
  192. if e := retErr.Load(); e != nil {
  193. err = e
  194. } else if ok {
  195. val = v
  196. } else {
  197. err = ErrReduceNoOutput
  198. }
  199. }
  200. return
  201. }
  202. // MapReduceVoid maps all elements generated from given generate,
  203. // and reduce the output elements with given reducer.
  204. func MapReduceVoid[T, U any](generate GenerateFunc[T], mapper MapperFunc[T, U],
  205. reducer VoidReducerFunc[U], opts ...Option) error {
  206. _, err := MapReduce(generate, mapper, func(input <-chan U, writer Writer[any], cancel func(error)) {
  207. reducer(input, cancel)
  208. }, opts...)
  209. if errors.Is(err, ErrReduceNoOutput) {
  210. return nil
  211. }
  212. return err
  213. }
  214. // WithContext customizes a mapreduce processing accepts a given ctx.
  215. func WithContext(ctx context.Context) Option {
  216. return func(opts *mapReduceOptions) {
  217. opts.ctx = ctx
  218. }
  219. }
  220. // WithWorkers customizes a mapreduce processing with given workers.
  221. func WithWorkers(workers int) Option {
  222. return func(opts *mapReduceOptions) {
  223. if workers < minWorkers {
  224. opts.workers = minWorkers
  225. } else {
  226. opts.workers = workers
  227. }
  228. }
  229. }
  230. func buildOptions(opts ...Option) *mapReduceOptions {
  231. options := newOptions()
  232. for _, opt := range opts {
  233. opt(options)
  234. }
  235. return options
  236. }
  237. func buildSource[T any](generate GenerateFunc[T], panicChan *onceChan) chan T {
  238. source := make(chan T)
  239. go func() {
  240. defer func() {
  241. if r := recover(); r != nil {
  242. panicChan.write(r)
  243. }
  244. close(source)
  245. }()
  246. generate(source)
  247. }()
  248. return source
  249. }
  // drain receives from channel and discards every value until the
  // channel is closed.
  func drain[T any](channel <-chan T) {
  	for {
  		if _, open := <-channel; !open {
  			return
  		}
  	}
  }
  256. func executeMappers[T, U any](mCtx mapperContext[T, U]) {
  257. var wg sync.WaitGroup
  258. defer func() {
  259. wg.Wait()
  260. close(mCtx.collector)
  261. drain(mCtx.source)
  262. }()
  263. var failed int32
  264. pool := make(chan struct{}, mCtx.workers)
  265. writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan)
  266. for atomic.LoadInt32(&failed) == 0 {
  267. select {
  268. case <-mCtx.ctx.Done():
  269. return
  270. case <-mCtx.doneChan:
  271. return
  272. case pool <- struct{}{}:
  273. item, ok := <-mCtx.source
  274. if !ok {
  275. <-pool
  276. return
  277. }
  278. wg.Add(1)
  279. go func() {
  280. defer func() {
  281. if r := recover(); r != nil {
  282. atomic.AddInt32(&failed, 1)
  283. mCtx.panicChan.write(r)
  284. }
  285. wg.Done()
  286. <-pool
  287. }()
  288. mCtx.mapper(item, writer)
  289. }()
  290. }
  291. }
  292. }
  293. func newOptions() *mapReduceOptions {
  294. return &mapReduceOptions{
  295. ctx: context.Background(),
  296. workers: defaultWorkers,
  297. }
  298. }
  // once wraps fn so that only the first call of the returned func reaches
  // fn; later calls are ignored.
  func once(fn func(error)) func(error) {
  	var guard sync.Once

  	return func(err error) {
  		guard.Do(func() {
  			fn(err)
  		})
  	}
  }
  // guardedWriter sends values to channel unless ctx is done or the done
  // channel has been closed.
  type guardedWriter[T any] struct {
  	ctx     context.Context
  	channel chan<- T
  	done    <-chan struct{}
  }

  // newGuardedWriter returns a guardedWriter that writes to channel and is
  // guarded by ctx and done.
  func newGuardedWriter[T any](ctx context.Context, channel chan<- T, done <-chan struct{}) guardedWriter[T] {
  	var w guardedWriter[T]
  	w.ctx = ctx
  	w.channel = channel
  	w.done = done
  	return w
  }

  // Write sends v to the underlying channel. The value is dropped when ctx
  // is already done or done is already closed at the time of the call;
  // otherwise the send blocks until the channel accepts v.
  func (w guardedWriter[T]) Write(v T) {
  	select {
  	case <-w.ctx.Done():
  		// Context finished: drop the value.
  	case <-w.done:
  		// Processing finished: drop the value.
  	default:
  		w.channel <- v
  	}
  }
  // onceChan is a channel wrapper that accepts at most one write.
  type onceChan struct {
  	channel chan any
  	// wrote is flipped from 0 to 1 by the first write.
  	wrote int32
  }

  // write sends val on the channel if this is the first write; every later
  // call returns without sending.
  func (oc *onceChan) write(val any) {
  	if !atomic.CompareAndSwapInt32(&oc.wrote, 0, 1) {
  		return
  	}

  	oc.channel <- val
  }
  337. }