mapreduce_fuzz_test.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. //go:build go1.18
  2. // +build go1.18
  3. package mr
  4. import (
  5. "fmt"
  6. "math/rand"
  7. "runtime"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/stretchr/testify/assert"
  12. "go.uber.org/goleak"
  13. )
  14. func FuzzMapReduce(f *testing.F) {
  15. rand.Seed(time.Now().UnixNano())
  16. f.Add(uint(10), uint(runtime.NumCPU()))
  17. f.Fuzz(func(t *testing.T, num, workers uint) {
  18. n := int64(num)%5000 + 5000
  19. genPanic := rand.Intn(100) == 0
  20. mapperPanic := rand.Intn(100) == 0
  21. reducerPanic := rand.Intn(100) == 0
  22. genIdx := rand.Int63n(n)
  23. mapperIdx := rand.Int63n(n)
  24. reducerIdx := rand.Int63n(n)
  25. squareSum := (n - 1) * n * (2*n - 1) / 6
  26. fn := func() (any, error) {
  27. defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
  28. return MapReduce(func(source chan<- any) {
  29. for i := int64(0); i < n; i++ {
  30. source <- i
  31. if genPanic && i == genIdx {
  32. panic("foo")
  33. }
  34. }
  35. }, func(item any, writer Writer, cancel func(error)) {
  36. v := item.(int64)
  37. if mapperPanic && v == mapperIdx {
  38. panic("bar")
  39. }
  40. writer.Write(v * v)
  41. }, func(pipe <-chan any, writer Writer, cancel func(error)) {
  42. var idx int64
  43. var total int64
  44. for v := range pipe {
  45. if reducerPanic && idx == reducerIdx {
  46. panic("baz")
  47. }
  48. total += v.(int64)
  49. idx++
  50. }
  51. writer.Write(total)
  52. }, WithWorkers(int(workers)%50+runtime.NumCPU()/2))
  53. }
  54. if genPanic || mapperPanic || reducerPanic {
  55. var buf strings.Builder
  56. buf.WriteString(fmt.Sprintf("n: %d", n))
  57. buf.WriteString(fmt.Sprintf(", genPanic: %t", genPanic))
  58. buf.WriteString(fmt.Sprintf(", mapperPanic: %t", mapperPanic))
  59. buf.WriteString(fmt.Sprintf(", reducerPanic: %t", reducerPanic))
  60. buf.WriteString(fmt.Sprintf(", genIdx: %d", genIdx))
  61. buf.WriteString(fmt.Sprintf(", mapperIdx: %d", mapperIdx))
  62. buf.WriteString(fmt.Sprintf(", reducerIdx: %d", reducerIdx))
  63. assert.Panicsf(t, func() { fn() }, buf.String())
  64. } else {
  65. val, err := fn()
  66. assert.Nil(t, err)
  67. assert.Equal(t, squareSum, val.(int64))
  68. }
  69. })
  70. }