Explorar o código

fix mapreduce problem when reducer doesn't write

kevin %!s(int64=4) %!d(string=hai) anos
pai
achega
e987eb60d3
Modificáronse 2 ficheiros con 55 adicións e 4 borrados
  1. 15 4
      core/mr/mapreduce.go
  2. 40 0
      example/mapreduce/deadlock/main.go

+ 15 - 4
core/mr/mapreduce.go

@@ -16,7 +16,10 @@ const (
 	minWorkers     = 1
 )
 
-var ErrCancelWithNil = errors.New("mapreduce cancelled with nil")
+var (
+	ErrCancelWithNil  = errors.New("mapreduce cancelled with nil")
+	ErrReduceNoOutput = errors.New("reduce not writing value")
+)
 
 type (
 	GenerateFunc    func(source chan<- interface{})
@@ -93,7 +96,14 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
 	collector := make(chan interface{}, options.workers)
 	done := syncx.NewDoneChan()
 	writer := newGuardedWriter(output, done.Done())
+	var closeOnce sync.Once
 	var retErr errorx.AtomicError
+	finish := func() {
+		closeOnce.Do(func() {
+			done.Close()
+			close(output)
+		})
+	}
 	cancel := once(func(err error) {
 		if err != nil {
 			retErr.Set(err)
@@ -102,14 +112,15 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
 		}
 
 		drain(source)
-		done.Close()
-		close(output)
+		finish()
 	})
 
 	go func() {
 		defer func() {
 			if r := recover(); r != nil {
 				cancel(fmt.Errorf("%v", r))
+			} else {
+				finish()
 			}
 		}()
 		reducer(collector, writer, cancel)
@@ -122,7 +133,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
 	} else if ok {
 		return value, nil
 	} else {
-		return nil, nil
+		return nil, ErrReduceNoOutput
 	}
 }
 

+ 40 - 0
example/mapreduce/deadlock/main.go

@@ -0,0 +1,40 @@
+package main
+
+import (
+	"log"
+	"strconv"
+
+	"github.com/tal-tech/go-zero/core/mr"
+)
+
+type User struct {
+	Uid  int
+	Name string
+}
+
+func main() {
+	uids := []int{111, 222, 333}
+	res, err := mr.MapReduce(func(source chan<- interface{}) {
+		for _, uid := range uids {
+			source <- uid
+		}
+	}, func(item interface{}, writer mr.Writer, cancel func(error)) {
+		uid := item.(int)
+		user := &User{
+			Uid:  uid,
+			Name: strconv.Itoa(uid),
+		}
+		writer.Write(user)
+	}, func(pipe <-chan interface{}, writer mr.Writer, cancel func(error)) {
+		var users []*User
+		for p := range pipe {
+			users = append(users, p.(*User))
+		}
+		// missing writer.Write(...), should not panic
+	})
+	if err != nil {
+		log.Print(err)
+		return
+	}
+	log.Print(len(res.([]*User)))
+}