
feat: handle panics in mapreduce by re-raising them in the calling goroutine, not inside worker goroutines (#1490)

* feat: handle panic

* chore: update fuzz test

* chore: optimize square sum algorithm
Kevin Wan committed 3 years ago
Commit 14a902c1a7
3 files changed, 372 insertions(+), 167 deletions(-)
  1. core/mr/mapreduce.go  (+116, -39)
  2. core/mr/mapreduce_fuzz_test.go  (+78, -0)
  3. core/mr/mapreduce_test.go  (+178, -128)
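
Before this change, a panic inside a reducer goroutine was recovered and converted into an error, and a panic inside a mapper goroutine was swallowed by threading.GoSafe; after it, the panic value is forwarded over a dedicated channel and re-raised in the goroutine that called MapReduce/ForEach. Below is a minimal caller-side sketch of the new behavior, using the signatures shown in this diff; the recover wrapper is only illustrative, not part of the change.

package main

import (
	"fmt"

	"github.com/zeromicro/go-zero/core/mr"
)

func main() {
	defer func() {
		// the mapper's panic now surfaces here, in the calling goroutine
		if r := recover(); r != nil {
			fmt.Println("recovered:", r)
		}
	}()

	_, _ = mr.MapReduce(func(source chan<- interface{}) {
		for i := 0; i < 10; i++ {
			source <- i
		}
	}, func(item interface{}, writer mr.Writer, cancel func(error)) {
		if item.(int) == 5 {
			panic("boom") // previously swallowed or turned into an error, now propagated
		}
		writer.Write(item.(int))
	}, func(pipe <-chan interface{}, writer mr.Writer, cancel func(error)) {
		total := 0
		for v := range pipe {
			total += v.(int)
		}
		writer.Write(total)
	})
}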

core/mr/mapreduce.go  (+116, -39)

@@ -3,12 +3,11 @@ package mr
 import (
 	"context"
 	"errors"
-	"fmt"
 	"sync"
+	"sync/atomic"
 
 	"github.com/zeromicro/go-zero/core/errorx"
 	"github.com/zeromicro/go-zero/core/lang"
-	"github.com/zeromicro/go-zero/core/threading"
 )
 
 const (
@@ -42,6 +41,16 @@ type (
 	// Option defines the method to customize the mapreduce.
 	Option func(opts *mapReduceOptions)
 
+	mapperContext struct {
+		ctx       context.Context
+		mapper    MapFunc
+		source    <-chan interface{}
+		panicChan *onceChan
+		collector chan<- interface{}
+		doneChan  <-chan lang.PlaceholderType
+		workers   int
+	}
+
 	mapReduceOptions struct {
 		ctx     context.Context
 		workers int
@@ -90,46 +99,72 @@ func FinishVoid(fns ...func()) {
 
 // ForEach maps all elements from given generate but no output.
 func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) {
-	drain(Map(generate, func(item interface{}, writer Writer) {
-		mapper(item)
-	}, opts...))
-}
-
-// Map maps all elements generated from given generate func, and returns an output channel.
-func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} {
 	options := buildOptions(opts...)
-	source := buildSource(generate)
+	panicChan := &onceChan{channel: make(chan interface{})}
+	source := buildSource(generate, panicChan)
 	collector := make(chan interface{}, options.workers)
 	done := make(chan lang.PlaceholderType)
 
-	go executeMappers(options.ctx, mapper, source, collector, done, options.workers)
+	go executeMappers(mapperContext{
+		ctx: options.ctx,
+		mapper: func(item interface{}, writer Writer) {
+			mapper(item)
+		},
+		source:    source,
+		panicChan: panicChan,
+		collector: collector,
+		doneChan:  done,
+		workers:   options.workers,
+	})
 
-	return collector
+	for {
+		select {
+		case v := <-panicChan.channel:
+			panic(v)
+		case _, ok := <-collector:
+			if !ok {
+				return
+			}
+		}
+	}
 }
 
 // MapReduce maps all elements generated from given generate func,
 // and reduces the output elements with given reducer.
 func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
 	opts ...Option) (interface{}, error) {
-	source := buildSource(generate)
-	return MapReduceChan(source, mapper, reducer, opts...)
+	panicChan := &onceChan{channel: make(chan interface{})}
+	source := buildSource(generate, panicChan)
+	return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
 }
 
 // MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
 func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
 	opts ...Option) (interface{}, error) {
+	panicChan := &onceChan{channel: make(chan interface{})}
+	return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
+}
+
+// mapReduceWithPanicChan maps all elements from source, and reduces the output elements with the given reducer.
+func mapReduceWithPanicChan(source <-chan interface{}, panicChan *onceChan, mapper MapperFunc,
+	reducer ReducerFunc, opts ...Option) (interface{}, error) {
 	options := buildOptions(opts...)
+	// output is used to write the final result
 	output := make(chan interface{})
 	defer func() {
+	// the reducer may only write once; writing more than once panics
 		for range output {
 			panic("more than one element written in reducer")
 		}
 	}()
 
+	// collector is used to collect data from the mapper, and is consumed in the reducer
 	collector := make(chan interface{}, options.workers)
+	// if done is closed, all mappers and reducer should stop processing
 	done := make(chan lang.PlaceholderType)
 	writer := newGuardedWriter(options.ctx, output, done)
 	var closeOnce sync.Once
+	// use errorx.AtomicError to avoid data races
 	var retErr errorx.AtomicError
 	finish := func() {
 		closeOnce.Do(func() {
@@ -151,30 +186,38 @@ func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer Reducer
 	go func() {
 		defer func() {
 			drain(collector)
-
 			if r := recover(); r != nil {
-				cancel(fmt.Errorf("%v", r))
-			} else {
-				finish()
+				panicChan.write(r)
 			}
+			finish()
 		}()
 
 		reducer(collector, writer, cancel)
 	}()
 
-	go executeMappers(options.ctx, func(item interface{}, w Writer) {
-		mapper(item, w, cancel)
-	}, source, collector, done, options.workers)
+	go executeMappers(mapperContext{
+		ctx: options.ctx,
+		mapper: func(item interface{}, w Writer) {
+			mapper(item, w, cancel)
+		},
+		source:    source,
+		panicChan: panicChan,
+		collector: collector,
+		doneChan:  done,
+		workers:   options.workers,
+	})
 
 	select {
 	case <-options.ctx.Done():
 		cancel(context.DeadlineExceeded)
 		return nil, context.DeadlineExceeded
-	case value, ok := <-output:
+	case v := <-panicChan.channel:
+		panic(v)
+	case v, ok := <-output:
 		if err := retErr.Load(); err != nil {
 			return nil, err
 		} else if ok {
-			return value, nil
+			return v, nil
 		} else {
 			return nil, ErrReduceNoOutput
 		}
@@ -221,12 +264,18 @@ func buildOptions(opts ...Option) *mapReduceOptions {
 	return options
 }
 
-func buildSource(generate GenerateFunc) chan interface{} {
+func buildSource(generate GenerateFunc, panicChan *onceChan) chan interface{} {
 	source := make(chan interface{})
-	threading.GoSafe(func() {
-		defer close(source)
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				panicChan.write(r)
+			}
+			close(source)
+		}()
+
 		generate(source)
-	})
+	}()
 
 	return source
 }
@@ -238,39 +287,54 @@ func drain(channel <-chan interface{}) {
 	}
 }
 
-func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{},
-	collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) {
+func executeMappers(mCtx mapperContext) {
 	var wg sync.WaitGroup
+	pc := &onceChan{channel: make(chan interface{})}
 	defer func() {
+		// in case a panic happens while processing the last item, the for loop below won't handle it.
+		select {
+		case r := <-pc.channel:
+			mCtx.panicChan.write(r)
+		default:
+		}
+
 		wg.Wait()
-		close(collector)
+		close(mCtx.collector)
+		drain(mCtx.source)
 	}()
 
-	pool := make(chan lang.PlaceholderType, workers)
-	writer := newGuardedWriter(ctx, collector, done)
+	pool := make(chan lang.PlaceholderType, mCtx.workers)
+	writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan)
 	for {
 		select {
-		case <-ctx.Done():
+		case <-mCtx.ctx.Done():
+			return
+		case <-mCtx.doneChan:
 			return
-		case <-done:
+		case r := <-pc.channel:
+			// make sure this method quits ASAP;
+			// without this case branch, all the items from source would be consumed before returning.
+			mCtx.panicChan.write(r)
 			return
 		case pool <- lang.Placeholder:
-			item, ok := <-input
+			item, ok := <-mCtx.source
 			if !ok {
 				<-pool
 				return
 			}
 
 			wg.Add(1)
-			// better to safely run caller defined method
-			threading.GoSafe(func() {
+			go func() {
 				defer func() {
+					if r := recover(); r != nil {
+						pc.write(r)
+					}
 					wg.Done()
 					<-pool
 				}()
 
-				mapper(item, writer)
-			})
+				mCtx.mapper(item, writer)
+			}()
 		}
 	}
 }
@@ -316,3 +380,16 @@ func (gw guardedWriter) Write(v interface{}) {
 		gw.channel <- v
 	}
 }
+
+type onceChan struct {
+	channel chan interface{}
+	wrote   int32
+}
+
+func (oc *onceChan) write(val interface{}) {
+	if atomic.AddInt32(&oc.wrote, 1) > 1 {
+		return
+	}
+
+	oc.channel <- val
+}

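The onceChan type added above forwards only the first panic value: atomic.AddInt32 lets exactly one writer proceed to the unbuffered channel, and every later writer returns immediately instead of blocking forever on a channel nobody reads again. Here is a standalone sketch of that write-once pattern; the type and variable names are illustrative, not part of the package.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// firstOnly mimics onceChan: the first write is delivered, the rest are dropped.
type firstOnly struct {
	ch    chan interface{}
	wrote int32
}

func (f *firstOnly) write(v interface{}) {
	if atomic.AddInt32(&f.wrote, 1) > 1 {
		return // someone already won the race; don't block on the channel
	}
	f.ch <- v
}

func main() {
	f := &firstOnly{ch: make(chan interface{})}
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		i := i
		wg.Add(1)
		go func() {
			defer wg.Done()
			f.write(i)
		}()
	}
	fmt.Println("first value:", <-f.ch) // exactly one value is ever received
	wg.Wait()
}
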
core/mr/mapreduce_fuzz_test.go  (+78, -0)

@@ -0,0 +1,78 @@
+//go:build go1.18
+// +build go1.18
+
+package mr
+
+import (
+	"fmt"
+	"math/rand"
+	"runtime"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/goleak"
+)
+
+func FuzzMapReduce(f *testing.F) {
+	rand.Seed(time.Now().UnixNano())
+
+	f.Add(int64(10), runtime.NumCPU())
+	f.Fuzz(func(t *testing.T, n int64, workers int) {
+		n = n%5000 + 5000
+		genPanic := rand.Intn(100) == 0
+		mapperPanic := rand.Intn(100) == 0
+		reducerPanic := rand.Intn(100) == 0
+		genIdx := rand.Int63n(n)
+		mapperIdx := rand.Int63n(n)
+		reducerIdx := rand.Int63n(n)
+		squareSum := (n - 1) * n * (2*n - 1) / 6
+
+		fn := func() (interface{}, error) {
+			defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
+
+			return MapReduce(func(source chan<- interface{}) {
+				for i := int64(0); i < n; i++ {
+					source <- i
+					if genPanic && i == genIdx {
+						panic("foo")
+					}
+				}
+			}, func(item interface{}, writer Writer, cancel func(error)) {
+				v := item.(int64)
+				if mapperPanic && v == mapperIdx {
+					panic("bar")
+				}
+				writer.Write(v * v)
+			}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+				var idx int64
+				var total int64
+				for v := range pipe {
+					if reducerPanic && idx == reducerIdx {
+						panic("baz")
+					}
+					total += v.(int64)
+					idx++
+				}
+				writer.Write(total)
+			}, WithWorkers(workers%50+runtime.NumCPU()))
+		}
+
+		if genPanic || mapperPanic || reducerPanic {
+			var buf strings.Builder
+			buf.WriteString(fmt.Sprintf("n: %d", n))
+			buf.WriteString(fmt.Sprintf(", genPanic: %t", genPanic))
+			buf.WriteString(fmt.Sprintf(", mapperPanic: %t", mapperPanic))
+			buf.WriteString(fmt.Sprintf(", reducerPanic: %t", reducerPanic))
+			buf.WriteString(fmt.Sprintf(", genIdx: %d", genIdx))
+			buf.WriteString(fmt.Sprintf(", mapperIdx: %d", mapperIdx))
+			buf.WriteString(fmt.Sprintf(", reducerIdx: %d", reducerIdx))
+			assert.Panicsf(t, func() { fn() }, buf.String())
+		} else {
+			val, err := fn()
+			assert.Nil(t, err)
+			assert.Equal(t, squareSum, val.(int64))
+		}
+	})
+}

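The fuzz test checks the MapReduce result against the closed-form sum of squares 0² + 1² + … + (n−1)² = (n−1)·n·(2n−1)/6, which is what the "optimize square sum algorithm" note in the commit message refers to. A quick self-check of that identity, as a hypothetical helper independent of the test file:

package main

import "fmt"

// sumSquares returns 0² + 1² + … + (n-1)² by direct summation.
func sumSquares(n int64) int64 {
	var total int64
	for i := int64(0); i < n; i++ {
		total += i * i
	}
	return total
}

func main() {
	n := int64(5000)
	closedForm := (n - 1) * n * (2*n - 1) / 6
	fmt.Println(sumSquares(n) == closedForm) // true
}
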
core/mr/mapreduce_test.go  (+178, -128)

@@ -11,8 +11,6 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/stringx"
-	"github.com/zeromicro/go-zero/core/syncx"
 	"go.uber.org/goleak"
 )
 
@@ -124,84 +122,69 @@ func TestForEach(t *testing.T) {
 	t.Run("all", func(t *testing.T) {
 		defer goleak.VerifyNone(t)
 
-		ForEach(func(source chan<- interface{}) {
-			for i := 0; i < tasks; i++ {
-				source <- i
-			}
-		}, func(item interface{}) {
-			panic("foo")
+		assert.PanicsWithValue(t, "foo", func() {
+			ForEach(func(source chan<- interface{}) {
+				for i := 0; i < tasks; i++ {
+					source <- i
+				}
+			}, func(item interface{}) {
+				panic("foo")
+			})
 		})
 	})
 }
 
-func TestMap(t *testing.T) {
+func TestGeneratePanic(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
-	tests := []struct {
-		mapper MapFunc
-		expect int
-	}{
-		{
-			mapper: func(item interface{}, writer Writer) {
-				v := item.(int)
-				writer.Write(v * v)
-			},
-			expect: 30,
-		},
-		{
-			mapper: func(item interface{}, writer Writer) {
-				v := item.(int)
-				if v%2 == 0 {
-					return
-				}
-				writer.Write(v * v)
-			},
-			expect: 10,
-		},
-		{
-			mapper: func(item interface{}, writer Writer) {
-				v := item.(int)
-				if v%2 == 0 {
-					panic(v)
-				}
-				writer.Write(v * v)
-			},
-			expect: 10,
-		},
-	}
+	t.Run("all", func(t *testing.T) {
+		assert.PanicsWithValue(t, "foo", func() {
+			ForEach(func(source chan<- interface{}) {
+				panic("foo")
+			}, func(item interface{}) {
+			})
+		})
+	})
+}
 
-	for _, test := range tests {
-		t.Run(stringx.Rand(), func(t *testing.T) {
-			channel := Map(func(source chan<- interface{}) {
-				for i := 1; i < 5; i++ {
+func TestMapperPanic(t *testing.T) {
+	defer goleak.VerifyNone(t)
+
+	const tasks = 1000
+	var run int32
+	t.Run("all", func(t *testing.T) {
+		assert.PanicsWithValue(t, "foo", func() {
+			_, _ = MapReduce(func(source chan<- interface{}) {
+				for i := 0; i < tasks; i++ {
 					source <- i
 				}
-			}, test.mapper, WithWorkers(-1))
-
-			var result int
-			for v := range channel {
-				result += v.(int)
-			}
-
-			assert.Equal(t, test.expect, result)
+			}, func(item interface{}, writer Writer, cancel func(error)) {
+				atomic.AddInt32(&run, 1)
+				panic("foo")
+			}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+			})
 		})
-	}
+		assert.True(t, atomic.LoadInt32(&run) < tasks/2)
+	})
 }
 
 func TestMapReduce(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
 	tests := []struct {
+		name        string
 		mapper      MapperFunc
 		reducer     ReducerFunc
 		expectErr   error
 		expectValue interface{}
 	}{
 		{
+			name:        "simple",
 			expectErr:   nil,
 			expectValue: 30,
 		},
 		{
+			name: "cancel with error",
 			mapper: func(item interface{}, writer Writer, cancel func(error)) {
 				v := item.(int)
 				if v%3 == 0 {
@@ -212,6 +195,7 @@ func TestMapReduce(t *testing.T) {
 			expectErr: errDummy,
 		},
 		{
+			name: "cancel with nil",
 			mapper: func(item interface{}, writer Writer, cancel func(error)) {
 				v := item.(int)
 				if v%3 == 0 {
@@ -223,6 +207,7 @@ func TestMapReduce(t *testing.T) {
 			expectValue: nil,
 		},
 		{
+			name: "cancel with more",
 			reducer: func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
 				var result int
 				for item := range pipe {
@@ -237,45 +222,68 @@ func TestMapReduce(t *testing.T) {
 		},
 	}
 
-	for _, test := range tests {
-		t.Run(stringx.Rand(), func(t *testing.T) {
-			if test.mapper == nil {
-				test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
-					v := item.(int)
-					writer.Write(v * v)
-				}
-			}
-			if test.reducer == nil {
-				test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
-					var result int
-					for item := range pipe {
-						result += item.(int)
+	t.Run("MapReduce", func(t *testing.T) {
+		for _, test := range tests {
+			t.Run(test.name, func(t *testing.T) {
+				if test.mapper == nil {
+					test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
+						v := item.(int)
+						writer.Write(v * v)
 					}
-					writer.Write(result)
 				}
-			}
-			value, err := MapReduce(func(source chan<- interface{}) {
-				for i := 1; i < 5; i++ {
-					source <- i
+				if test.reducer == nil {
+					test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+						var result int
+						for item := range pipe {
+							result += item.(int)
+						}
+						writer.Write(result)
+					}
 				}
-			}, test.mapper, test.reducer, WithWorkers(runtime.NumCPU()))
+				value, err := MapReduce(func(source chan<- interface{}) {
+					for i := 1; i < 5; i++ {
+						source <- i
+					}
+				}, test.mapper, test.reducer, WithWorkers(runtime.NumCPU()))
 
-			assert.Equal(t, test.expectErr, err)
-			assert.Equal(t, test.expectValue, value)
-		})
-	}
-}
+				assert.Equal(t, test.expectErr, err)
+				assert.Equal(t, test.expectValue, value)
+			})
+		}
+	})
 
-func TestMapReducePanicBothMapperAndReducer(t *testing.T) {
-	defer goleak.VerifyNone(t)
+	t.Run("MapReduce", func(t *testing.T) {
+		for _, test := range tests {
+			t.Run(test.name, func(t *testing.T) {
+				if test.mapper == nil {
+					test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
+						v := item.(int)
+						writer.Write(v * v)
+					}
+				}
+				if test.reducer == nil {
+					test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+						var result int
+						for item := range pipe {
+							result += item.(int)
+						}
+						writer.Write(result)
+					}
+				}
 
-	_, _ = MapReduce(func(source chan<- interface{}) {
-		source <- 0
-		source <- 1
-	}, func(item interface{}, writer Writer, cancel func(error)) {
-		panic("foo")
-	}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
-		panic("bar")
+				source := make(chan interface{})
+				go func() {
+					for i := 1; i < 5; i++ {
+						source <- i
+					}
+					close(source)
+				}()
+
+				value, err := MapReduceChan(source, test.mapper, test.reducer, WithWorkers(-1))
+				assert.Equal(t, test.expectErr, err)
+				assert.Equal(t, test.expectValue, value)
+			})
+		}
 	})
 }
 
@@ -302,16 +310,19 @@ func TestMapReduceVoid(t *testing.T) {
 
 	var value uint32
 	tests := []struct {
+		name        string
 		mapper      MapperFunc
 		reducer     VoidReducerFunc
 		expectValue uint32
 		expectErr   error
 	}{
 		{
+			name:        "simple",
 			expectValue: 30,
 			expectErr:   nil,
 		},
 		{
+			name: "cancel with error",
 			mapper: func(item interface{}, writer Writer, cancel func(error)) {
 				v := item.(int)
 				if v%3 == 0 {
@@ -322,6 +333,7 @@ func TestMapReduceVoid(t *testing.T) {
 			expectErr: errDummy,
 		},
 		{
+			name: "cancel with nil",
 			mapper: func(item interface{}, writer Writer, cancel func(error)) {
 				v := item.(int)
 				if v%3 == 0 {
@@ -332,6 +344,7 @@ func TestMapReduceVoid(t *testing.T) {
 			expectErr: ErrCancelWithNil,
 		},
 		{
+			name: "cancel with more",
 			reducer: func(pipe <-chan interface{}, cancel func(error)) {
 				for item := range pipe {
 					result := atomic.AddUint32(&value, uint32(item.(int)))
@@ -345,7 +358,7 @@ func TestMapReduceVoid(t *testing.T) {
 	}
 
 	for _, test := range tests {
-		t.Run(stringx.Rand(), func(t *testing.T) {
+		t.Run(test.name, func(t *testing.T) {
 			atomic.StoreUint32(&value, 0)
 
 			if test.mapper == nil {
@@ -400,39 +413,59 @@ func TestMapReduceVoidWithDelay(t *testing.T) {
 	assert.Equal(t, 0, result[1])
 }
 
-func TestMapVoid(t *testing.T) {
+func TestMapReducePanic(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
-	const tasks = 1000
-	var count uint32
-	ForEach(func(source chan<- interface{}) {
-		for i := 0; i < tasks; i++ {
-			source <- i
-		}
-	}, func(item interface{}) {
-		atomic.AddUint32(&count, 1)
+	assert.Panics(t, func() {
+		_, _ = MapReduce(func(source chan<- interface{}) {
+			source <- 0
+			source <- 1
+		}, func(item interface{}, writer Writer, cancel func(error)) {
+			i := item.(int)
+			writer.Write(i)
+		}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+			for range pipe {
+				panic("panic")
+			}
+		})
 	})
+}
 
-	assert.Equal(t, tasks, int(count))
+func TestMapReducePanicOnce(t *testing.T) {
+	defer goleak.VerifyNone(t)
+
+	assert.Panics(t, func() {
+		_, _ = MapReduce(func(source chan<- interface{}) {
+			for i := 0; i < 100; i++ {
+				source <- i
+			}
+		}, func(item interface{}, writer Writer, cancel func(error)) {
+			i := item.(int)
+			if i == 0 {
+				panic("foo")
+			}
+			writer.Write(i)
+		}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+			for range pipe {
+				panic("bar")
+			}
+		})
+	})
 }
 
-func TestMapReducePanic(t *testing.T) {
+func TestMapReducePanicBothMapperAndReducer(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
-	v, err := MapReduce(func(source chan<- interface{}) {
-		source <- 0
-		source <- 1
-	}, func(item interface{}, writer Writer, cancel func(error)) {
-		i := item.(int)
-		writer.Write(i)
-	}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
-		for range pipe {
-			panic("panic")
-		}
+	assert.Panics(t, func() {
+		_, _ = MapReduce(func(source chan<- interface{}) {
+			source <- 0
+			source <- 1
+		}, func(item interface{}, writer Writer, cancel func(error)) {
+			panic("foo")
+		}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+			panic("bar")
+		})
 	})
-	assert.Nil(t, v)
-	assert.NotNil(t, err)
-	assert.Equal(t, "panic", err.Error())
 }
 
 func TestMapReduceVoidCancel(t *testing.T) {
@@ -461,13 +494,13 @@ func TestMapReduceVoidCancel(t *testing.T) {
 func TestMapReduceVoidCancelWithRemains(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
-	var done syncx.AtomicBool
+	var done int32
 	var result []int
 	err := MapReduceVoid(func(source chan<- interface{}) {
 		for i := 0; i < defaultWorkers*2; i++ {
 			source <- i
 		}
-		done.Set(true)
+		atomic.AddInt32(&done, 1)
 	}, func(item interface{}, writer Writer, cancel func(error)) {
 		i := item.(int)
 		if i == defaultWorkers/2 {
@@ -482,7 +515,7 @@ func TestMapReduceVoidCancelWithRemains(t *testing.T) {
 	})
 	assert.NotNil(t, err)
 	assert.Equal(t, "anything", err.Error())
-	assert.True(t, done.True())
+	assert.Equal(t, int32(1), done)
 }
 
 func TestMapReduceWithoutReducerWrite(t *testing.T) {
@@ -507,34 +540,51 @@ func TestMapReduceVoidPanicInReducer(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
 	const message = "foo"
-	var done syncx.AtomicBool
-	err := MapReduceVoid(func(source chan<- interface{}) {
+	assert.Panics(t, func() {
+		var done int32
+		_ = MapReduceVoid(func(source chan<- interface{}) {
+			for i := 0; i < defaultWorkers*2; i++ {
+				source <- i
+			}
+			atomic.AddInt32(&done, 1)
+		}, func(item interface{}, writer Writer, cancel func(error)) {
+			i := item.(int)
+			writer.Write(i)
+		}, func(pipe <-chan interface{}, cancel func(error)) {
+			panic(message)
+		}, WithWorkers(1))
+	})
+}
+
+func TestForEachWithContext(t *testing.T) {
+	defer goleak.VerifyNone(t)
+
+	var done int32
+	ctx, cancel := context.WithCancel(context.Background())
+	ForEach(func(source chan<- interface{}) {
 		for i := 0; i < defaultWorkers*2; i++ {
 			source <- i
 		}
-		done.Set(true)
-	}, func(item interface{}, writer Writer, cancel func(error)) {
+		atomic.AddInt32(&done, 1)
+	}, func(item interface{}) {
 		i := item.(int)
-		writer.Write(i)
-	}, func(pipe <-chan interface{}, cancel func(error)) {
-		panic(message)
-	}, WithWorkers(1))
-	assert.NotNil(t, err)
-	assert.Equal(t, message, err.Error())
-	assert.True(t, done.True())
+		if i == defaultWorkers/2 {
+			cancel()
+		}
+	}, WithContext(ctx))
 }
 
 func TestMapReduceWithContext(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
-	var done syncx.AtomicBool
+	var done int32
 	var result []int
 	ctx, cancel := context.WithCancel(context.Background())
 	err := MapReduceVoid(func(source chan<- interface{}) {
 		for i := 0; i < defaultWorkers*2; i++ {
 			source <- i
 		}
-		done.Set(true)
+		atomic.AddInt32(&done, 1)
 	}, func(item interface{}, writer Writer, c func(error)) {
 		i := item.(int)
 		if i == defaultWorkers/2 {