Browse Source

remove unnecessary drain, fix data race (#1435)

* remove unnecessary drain, fix data race

* chore: fix parameter order

* refactor: rename MapVoid to ForEach in mr
Kevin Wan, 3 năm trước
mục cha
commit
8d6d37f71e
3 tập tin đã thay đổi với 67 bổ sung và 19 xóa
  1. 17 17
      core/mr/mapreduce.go
  2. 49 1
      core/mr/mapreduce_test.go
  3. 1 1
      go.mod

+ 17 - 17
core/mr/mapreduce.go

@@ -24,12 +24,12 @@ var (
 )
 
 type (
+	// ForEachFunc is used to do element processing, but no output.
+	ForEachFunc func(item interface{})
 	// GenerateFunc is used to let callers send elements into source.
 	GenerateFunc func(source chan<- interface{})
 	// MapFunc is used to do element processing and write the output to writer.
 	MapFunc func(item interface{}, writer Writer)
-	// VoidMapFunc is used to do element processing, but no output.
-	VoidMapFunc func(item interface{})
 	// MapperFunc is used to do element processing and write the output to writer,
 	// use cancel func to cancel the processing.
 	MapperFunc func(item interface{}, writer Writer, cancel func(error))
@@ -69,7 +69,6 @@ func Finish(fns ...func() error) error {
 			cancel(err)
 		}
 	}, func(pipe <-chan interface{}, cancel func(error)) {
-		drain(pipe)
 	}, WithWorkers(len(fns)))
 }
 
@@ -79,7 +78,7 @@ func FinishVoid(fns ...func()) {
 		return
 	}
 
-	MapVoid(func(source chan<- interface{}) {
+	ForEach(func(source chan<- interface{}) {
 		for _, fn := range fns {
 			source <- fn
 		}
@@ -89,6 +88,13 @@ func FinishVoid(fns ...func()) {
 	}, WithWorkers(len(fns)))
 }
 
+// ForEach runs mapper on every element produced by generate and yields no output to the caller.
+func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) {
+	drain(Map(generate, func(item interface{}, writer Writer) { // adapt mapper to a MapFunc; writer is never written to
+		mapper(item)
+	}, opts...)) // Map's output channel is handed to drain (defined outside this hunk) before ForEach returns
+}
+
 // Map maps all elements generated from given generate func, and returns an output channel.
 func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} {
 	options := buildOptions(opts...)
@@ -106,11 +112,11 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{}
 func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
 	opts ...Option) (interface{}, error) {
 	source := buildSource(generate) // turn the generate callback into the source channel consumed below
-	return MapReduceWithSource(source, mapper, reducer, opts...)
+	return MapReduceChan(source, mapper, reducer, opts...) // same call as before; only the callee was renamed
 }
 
-// MapReduceWithSource maps all elements from source, and reduce the output elements with given reducer.
-func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
+// MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
+func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
 	opts ...Option) (interface{}, error) {
 	options := buildOptions(opts...)
 	output := make(chan interface{})
@@ -180,18 +186,12 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
 func MapReduceVoid(generate GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error {
 	_, err := MapReduce(generate, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) {
 		reducer(input, cancel)
-		// We need to write a placeholder to let MapReduce to continue on reducer done,
-		// otherwise, all goroutines are waiting. The placeholder will be discarded by MapReduce.
-		writer.Write(lang.Placeholder)
 	}, opts...)
-	return err
-}
+	if errors.Is(err, ErrReduceNoOutput) { // the void reducer no longer writes a placeholder; presumably MapReduce reports that as ErrReduceNoOutput (verify)
+		return nil // "no output" is the expected outcome for a void reducer, not a failure
+	}
 
-// MapVoid maps all elements from given generate but no output.
-func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) {
-	drain(Map(generate, func(item interface{}, writer Writer) {
-		mapper(item)
-	}, opts...))
+	return err // any other error (or nil) is passed through unchanged
 }
 
 // WithContext customizes a mapreduce processing accepts a given ctx.

+ 49 - 1
core/mr/mapreduce_test.go

@@ -86,6 +86,54 @@ func TestFinishVoid(t *testing.T) {
 	assert.Equal(t, uint32(10), atomic.LoadUint32(&total))
 }
 
+func TestForEach(t *testing.T) { // covers: every item visited, filtered counting, and a panicking mapper
+	const tasks = 1000
+
+	t.Run("all", func(t *testing.T) {
+		defer goleak.VerifyNone(t) // fail the subtest if ForEach leaves goroutines behind
+
+		var count uint32
+		ForEach(func(source chan<- interface{}) {
+			for i := 0; i < tasks; i++ {
+				source <- i
+			}
+		}, func(item interface{}) {
+			atomic.AddUint32(&count, 1) // mapper may run on several workers, so count atomically
+		}, WithWorkers(-1)) // deliberately pass a non-positive worker count
+
+		assert.Equal(t, tasks, int(count))
+	})
+
+	t.Run("odd", func(t *testing.T) {
+		defer goleak.VerifyNone(t)
+
+		var count uint32
+		ForEach(func(source chan<- interface{}) {
+			for i := 0; i < tasks; i++ {
+				source <- i
+			}
+		}, func(item interface{}) {
+			if item.(int)%2 == 0 { // only every other item is counted, so exactly half of tasks is expected
+				atomic.AddUint32(&count, 1)
+			}
+		})
+
+		assert.Equal(t, tasks/2, int(count))
+	})
+
+	t.Run("panic", func(t *testing.T) { // was a second "all"; a unique name keeps -run filters and reports unambiguous
+		defer goleak.VerifyNone(t)
+
+		ForEach(func(source chan<- interface{}) {
+			for i := 0; i < tasks; i++ {
+				source <- i
+			}
+		}, func(item interface{}) {
+			panic("foo") // every item panics; the subtest passes only if ForEach still returns and leaks nothing
+		})
+	})
+}
+
 func TestMap(t *testing.T) {
 	defer goleak.VerifyNone(t)
 
@@ -344,7 +392,7 @@ func TestMapVoid(t *testing.T) {
 
 	const tasks = 1000
 	var count uint32
-	MapVoid(func(source chan<- interface{}) {
+	ForEach(func(source chan<- interface{}) {
 		for i := 0; i < tasks; i++ {
 			source <- i
 		}

+ 1 - 1
go.mod

@@ -29,7 +29,7 @@ require (
 	go.opentelemetry.io/otel/sdk v1.1.0
 	go.opentelemetry.io/otel/trace v1.1.0
 	go.uber.org/automaxprocs v1.4.0
-	go.uber.org/goleak v1.1.12 // indirect
+	go.uber.org/goleak v1.1.12
 	golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f // indirect
 	golang.org/x/sys v0.0.0-20211106132015-ebca88c72f68 // indirect
 	golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac