소스 검색

fix: goroutine stuck on edge case (#1495)

* fix: goroutine stuck on edge case

* refactor: simplify mapreduce implementation
Kevin Wan 3 년 전
부모
커밋
6c2abe7474
3개의 변경된 파일115개의 추가작업 그리고 19개의 파일을 삭제
  1. 4 15
      core/mr/mapreduce.go
  2. 4 4
      core/mr/mapreduce_fuzz_test.go
  3. 107 0
      core/mr/mapreduce_rand_test.go

+ 4 - 15
core/mr/mapreduce.go

@@ -289,33 +289,21 @@ func drain(channel <-chan interface{}) {
 
 func executeMappers(mCtx mapperContext) {
 	var wg sync.WaitGroup
-	pc := &onceChan{channel: make(chan interface{})}
 	defer func() {
-		// in case panic happens when processing last item, for loop not handling it.
-		select {
-		case r := <-pc.channel:
-			mCtx.panicChan.write(r)
-		default:
-		}
-
 		wg.Wait()
 		close(mCtx.collector)
 		drain(mCtx.source)
 	}()
 
+	var failed int32
 	pool := make(chan lang.PlaceholderType, mCtx.workers)
 	writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan)
-	for {
+	for atomic.LoadInt32(&failed) == 0 {
 		select {
 		case <-mCtx.ctx.Done():
 			return
 		case <-mCtx.doneChan:
 			return
-		case r := <-pc.channel:
-			// make sure this method quit ASAP,
-			// without this case branch, all the items from source will be consumed.
-			mCtx.panicChan.write(r)
-			return
 		case pool <- lang.Placeholder:
 			item, ok := <-mCtx.source
 			if !ok {
@@ -327,7 +315,8 @@ func executeMappers(mCtx mapperContext) {
 			go func() {
 				defer func() {
 					if r := recover(); r != nil {
-						pc.write(r)
+						atomic.AddInt32(&failed, 1)
+						mCtx.panicChan.write(r)
 					}
 					wg.Done()
 					<-pool

+ 4 - 4
core/mr/mapreduce_fuzz_test.go

@@ -18,9 +18,9 @@ import (
 func FuzzMapReduce(f *testing.F) {
 	rand.Seed(time.Now().UnixNano())
 
-	f.Add(int64(10), runtime.NumCPU())
-	f.Fuzz(func(t *testing.T, n int64, workers int) {
-		n = n%5000 + 5000
+	f.Add(uint(10), uint(runtime.NumCPU()))
+	f.Fuzz(func(t *testing.T, num uint, workers uint) {
+		n := int64(num)%5000 + 5000
 		genPanic := rand.Intn(100) == 0
 		mapperPanic := rand.Intn(100) == 0
 		reducerPanic := rand.Intn(100) == 0
@@ -56,7 +56,7 @@ func FuzzMapReduce(f *testing.F) {
 					idx++
 				}
 				writer.Write(total)
-			}, WithWorkers(workers%50+runtime.NumCPU()))
+			}, WithWorkers(int(workers)%50+runtime.NumCPU()/2))
 		}
 
 		if genPanic || mapperPanic || reducerPanic {

+ 107 - 0
core/mr/mapreduce_rand_test.go

@@ -0,0 +1,107 @@
+//go:build fuzz
+// +build fuzz
+
+package mr
+
+import (
+	"fmt"
+	"math/rand"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/zeromicro/go-zero/core/threading"
+	"gopkg.in/cheggaaa/pb.v1"
+)
+
+// If Fuzz stuck, we don't know why, because it only returns hung or unexpected,
+// so we need to simulate the fuzz test in test mode.
+func TestMapReduceRandom(t *testing.T) {
+	rand.Seed(time.Now().UnixNano())
+
+	const (
+		times  = 10000
+		nRange = 500
+		mega   = 1024 * 1024
+	)
+
+	bar := pb.New(times).Start()
+	runner := threading.NewTaskRunner(runtime.NumCPU())
+	var wg sync.WaitGroup
+	wg.Add(times)
+	for i := 0; i < times; i++ {
+		runner.Schedule(func() {
+			start := time.Now()
+			defer func() {
+				if time.Since(start) > time.Minute {
+					t.Fatal("timeout")
+				}
+				wg.Done()
+			}()
+
+			t.Run(strconv.Itoa(i), func(t *testing.T) {
+				n := rand.Int63n(nRange)%nRange + nRange
+				workers := rand.Int()%50 + runtime.NumCPU()/2
+				genPanic := rand.Intn(100) == 0
+				mapperPanic := rand.Intn(100) == 0
+				reducerPanic := rand.Intn(100) == 0
+				genIdx := rand.Int63n(n)
+				mapperIdx := rand.Int63n(n)
+				reducerIdx := rand.Int63n(n)
+				squareSum := (n - 1) * n * (2*n - 1) / 6
+
+				fn := func() (interface{}, error) {
+					return MapReduce(func(source chan<- interface{}) {
+						for i := int64(0); i < n; i++ {
+							source <- i
+							if genPanic && i == genIdx {
+								panic("foo")
+							}
+						}
+					}, func(item interface{}, writer Writer, cancel func(error)) {
+						v := item.(int64)
+						if mapperPanic && v == mapperIdx {
+							panic("bar")
+						}
+						writer.Write(v * v)
+					}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+						var idx int64
+						var total int64
+						for v := range pipe {
+							if reducerPanic && idx == reducerIdx {
+								panic("baz")
+							}
+							total += v.(int64)
+							idx++
+						}
+						writer.Write(total)
+					}, WithWorkers(int(workers)%50+runtime.NumCPU()/2))
+				}
+
+				if genPanic || mapperPanic || reducerPanic {
+					var buf strings.Builder
+					buf.WriteString(fmt.Sprintf("n: %d", n))
+					buf.WriteString(fmt.Sprintf(", genPanic: %t", genPanic))
+					buf.WriteString(fmt.Sprintf(", mapperPanic: %t", mapperPanic))
+					buf.WriteString(fmt.Sprintf(", reducerPanic: %t", reducerPanic))
+					buf.WriteString(fmt.Sprintf(", genIdx: %d", genIdx))
+					buf.WriteString(fmt.Sprintf(", mapperIdx: %d", mapperIdx))
+					buf.WriteString(fmt.Sprintf(", reducerIdx: %d", reducerIdx))
+					assert.Panicsf(t, func() { fn() }, buf.String())
+				} else {
+					val, err := fn()
+					assert.Nil(t, err)
+					assert.Equal(t, squareSum, val.(int64))
+				}
+				bar.Increment()
+			})
+		})
+	}
+
+	wg.Wait()
+	bar.Finish()
+}