Browse Source

feat: support context in MapReduce (#1368)

Kevin Wan 3 years ago
parent
commit
c0647f0719
2 changed files with 71 additions and 8 deletions
  1. 26 8
      core/mr/mapreduce.go
  2. 45 0
      core/mr/mapreduce_test.go

+ 26 - 8
core/mr/mapreduce.go

@@ -1,6 +1,7 @@
 package mr
 package mr
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"sync"
 	"sync"
@@ -43,6 +44,7 @@ type (
 	Option func(opts *mapReduceOptions)
 	Option func(opts *mapReduceOptions)
 
 
 	mapReduceOptions struct {
 	mapReduceOptions struct {
+		ctx     context.Context
 		workers int
 		workers int
 	}
 	}
 
 
@@ -95,14 +97,15 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{}
 	collector := make(chan interface{}, options.workers)
 	collector := make(chan interface{}, options.workers)
 	done := syncx.NewDoneChan()
 	done := syncx.NewDoneChan()
 
 
-	go executeMappers(mapper, source, collector, done.Done(), options.workers)
+	go executeMappers(options.ctx, mapper, source, collector, done.Done(), options.workers)
 
 
 	return collector
 	return collector
 }
 }
 
 
 // MapReduce maps all elements generated from given generate func,
 // MapReduce maps all elements generated from given generate func,
 // and reduces the output elements with given reducer.
 // and reduces the output elements with given reducer.
-func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) {
+func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
+	opts ...Option) (interface{}, error) {
 	source := buildSource(generate)
 	source := buildSource(generate)
 	return MapReduceWithSource(source, mapper, reducer, opts...)
 	return MapReduceWithSource(source, mapper, reducer, opts...)
 }
 }
@@ -120,7 +123,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
 
 
 	collector := make(chan interface{}, options.workers)
 	collector := make(chan interface{}, options.workers)
 	done := syncx.NewDoneChan()
 	done := syncx.NewDoneChan()
-	writer := newGuardedWriter(output, done.Done())
+	writer := newGuardedWriter(options.ctx, output, done.Done())
 	var closeOnce sync.Once
 	var closeOnce sync.Once
 	var retErr errorx.AtomicError
 	var retErr errorx.AtomicError
 	finish := func() {
 	finish := func() {
@@ -154,7 +157,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
 		reducer(collector, writer, cancel)
 		reducer(collector, writer, cancel)
 	}()
 	}()
 
 
-	go executeMappers(func(item interface{}, w Writer) {
+	go executeMappers(options.ctx, func(item interface{}, w Writer) {
 		mapper(item, w, cancel)
 		mapper(item, w, cancel)
 	}, source, collector, done.Done(), options.workers)
 	}, source, collector, done.Done(), options.workers)
 
 
@@ -187,6 +190,13 @@ func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) {
 	}, opts...))
 	}, opts...))
 }
 }
 
 
+// WithContext customizes a mapreduce processing to accept the given ctx.
+func WithContext(ctx context.Context) Option {
+	return func(opts *mapReduceOptions) {
+		opts.ctx = ctx
+	}
+}
+
 // WithWorkers customizes a mapreduce processing with given workers.
 // WithWorkers customizes a mapreduce processing with given workers.
 func WithWorkers(workers int) Option {
 func WithWorkers(workers int) Option {
 	return func(opts *mapReduceOptions) {
 	return func(opts *mapReduceOptions) {
@@ -224,8 +234,8 @@ func drain(channel <-chan interface{}) {
 	}
 	}
 }
 }
 
 
-func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- interface{},
-	done <-chan lang.PlaceholderType, workers int) {
+func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{},
+	collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) {
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	defer func() {
 	defer func() {
 		wg.Wait()
 		wg.Wait()
@@ -233,9 +243,11 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i
 	}()
 	}()
 
 
 	pool := make(chan lang.PlaceholderType, workers)
 	pool := make(chan lang.PlaceholderType, workers)
-	writer := newGuardedWriter(collector, done)
+	writer := newGuardedWriter(ctx, collector, done)
 	for {
 	for {
 		select {
 		select {
+		case <-ctx.Done():
+			return
 		case <-done:
 		case <-done:
 			return
 			return
 		case pool <- lang.Placeholder:
 		case pool <- lang.Placeholder:
@@ -261,6 +273,7 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i
 
 
 func newOptions() *mapReduceOptions {
 func newOptions() *mapReduceOptions {
 	return &mapReduceOptions{
 	return &mapReduceOptions{
+		ctx:     context.Background(),
 		workers: defaultWorkers,
 		workers: defaultWorkers,
 	}
 	}
 }
 }
@@ -275,12 +288,15 @@ func once(fn func(error)) func(error) {
 }
 }
 
 
 type guardedWriter struct {
 type guardedWriter struct {
+	ctx     context.Context
 	channel chan<- interface{}
 	channel chan<- interface{}
 	done    <-chan lang.PlaceholderType
 	done    <-chan lang.PlaceholderType
 }
 }
 
 
-func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderType) guardedWriter {
+func newGuardedWriter(ctx context.Context, channel chan<- interface{},
+	done <-chan lang.PlaceholderType) guardedWriter {
 	return guardedWriter{
 	return guardedWriter{
+		ctx:     ctx,
 		channel: channel,
 		channel: channel,
 		done:    done,
 		done:    done,
 	}
 	}
@@ -288,6 +304,8 @@ func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderTy
 
 
 func (gw guardedWriter) Write(v interface{}) {
 func (gw guardedWriter) Write(v interface{}) {
 	select {
 	select {
+	case <-gw.ctx.Done():
+		return
 	case <-gw.done:
 	case <-gw.done:
 		return
 		return
 	default:
 	default:

+ 45 - 0
core/mr/mapreduce_test.go

@@ -1,6 +1,7 @@
 package mr
 package mr
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"io/ioutil"
 	"io/ioutil"
 	"log"
 	"log"
@@ -410,6 +411,50 @@ func TestMapReduceWithoutReducerWrite(t *testing.T) {
 	assert.Nil(t, res)
 	assert.Nil(t, res)
 }
 }
 
 
+func TestMapReduceVoidPanicInReducer(t *testing.T) {
+	const message = "foo"
+	var done syncx.AtomicBool
+	err := MapReduceVoid(func(source chan<- interface{}) {
+		for i := 0; i < defaultWorkers*2; i++ {
+			source <- i
+		}
+		done.Set(true)
+	}, func(item interface{}, writer Writer, cancel func(error)) {
+		i := item.(int)
+		writer.Write(i)
+	}, func(pipe <-chan interface{}, cancel func(error)) {
+		panic(message)
+	}, WithWorkers(1))
+	assert.NotNil(t, err)
+	assert.Equal(t, message, err.Error())
+	assert.True(t, done.True())
+}
+
+func TestMapReduceWithContext(t *testing.T) {
+	var done syncx.AtomicBool
+	var result []int
+	ctx, cancel := context.WithCancel(context.Background())
+	err := MapReduceVoid(func(source chan<- interface{}) {
+		for i := 0; i < defaultWorkers*2; i++ {
+			source <- i
+		}
+		done.Set(true)
+	}, func(item interface{}, writer Writer, c func(error)) {
+		i := item.(int)
+		if i == defaultWorkers/2 {
+			cancel()
+		}
+		writer.Write(i)
+	}, func(pipe <-chan interface{}, cancel func(error)) {
+		for item := range pipe {
+			i := item.(int)
+			result = append(result, i)
+		}
+	}, WithContext(ctx))
+	assert.NotNil(t, err)
+	assert.Equal(t, ErrReduceNoOutput, err)
+}
+
 func BenchmarkMapReduce(b *testing.B) {
 func BenchmarkMapReduce(b *testing.B) {
 	b.ReportAllocs()
 	b.ReportAllocs()