Forráskód Böngészése

chore: optimize string search with Aho–Corasick algorithm (#1476)

* chore: optimize string search with Aho–Corasick algorithm

* chore: optimize keywords replacer

* fix: replacer bugs

* chore: reorder members
Kevin Wan 3 éve
szülő
commit
f1102fb262

+ 73 - 4
core/stringx/node.go

@@ -2,6 +2,8 @@ package stringx
 
 
 type node struct {
 type node struct {
 	children map[rune]*node
 	children map[rune]*node
+	fail     *node
+	depth    int
 	end      bool
 	end      bool
 }
 }
 
 
@@ -12,17 +14,19 @@ func (n *node) add(word string) {
 	}
 	}
 
 
 	nd := n
 	nd := n
-	for _, char := range chars {
+	var depth int
+	for i, char := range chars {
 		if nd.children == nil {
 		if nd.children == nil {
 			child := new(node)
 			child := new(node)
-			nd.children = map[rune]*node{
-				char: child,
-			}
+			child.depth = i + 1
+			nd.children = map[rune]*node{char: child}
 			nd = child
 			nd = child
 		} else if child, ok := nd.children[char]; ok {
 		} else if child, ok := nd.children[char]; ok {
 			nd = child
 			nd = child
+			depth++
 		} else {
 		} else {
 			child := new(node)
 			child := new(node)
+			child.depth = i + 1
 			nd.children[char] = child
 			nd.children[char] = child
 			nd = child
 			nd = child
 		}
 		}
@@ -30,3 +34,68 @@ func (n *node) add(word string) {
 
 
 	nd.end = true
 	nd.end = true
 }
 }
+
+func (n *node) build() {
+	n.fail = n
+	for _, child := range n.children {
+		child.fail = n
+		n.buildNode(child)
+	}
+}
+
+func (n *node) buildNode(nd *node) {
+	if nd.children == nil {
+		return
+	}
+
+	var fifo []*node
+	for key, child := range nd.children {
+		fifo = append(fifo, child)
+
+		if fail, ok := nd.fail.children[key]; ok {
+			child.fail = fail
+		} else {
+			child.fail = n
+		}
+	}
+
+	for _, val := range fifo {
+		n.buildNode(val)
+	}
+}
+
+func (n *node) find(chars []rune) []scope {
+	var scopes []scope
+	size := len(chars)
+	cur := n
+
+	for i := 0; i < size; i++ {
+		child, ok := cur.children[chars[i]]
+		if ok {
+			cur = child
+		} else if cur == n {
+			continue
+		} else {
+			cur = cur.fail
+			if child, ok = cur.children[chars[i]]; !ok {
+				continue
+			}
+			cur = child
+		}
+
+		if child.end {
+			scopes = append(scopes, scope{
+				start: i + 1 - child.depth,
+				stop:  i + 1,
+			})
+		}
+		if child.fail != n && child.fail.end {
+			scopes = append(scopes, scope{
+				start: i + 1 - child.fail.depth,
+				stop:  i + 1,
+			})
+		}
+	}
+
+	return scopes
+}

+ 25 - 0
core/stringx/node_test.go

@@ -0,0 +1,25 @@
+package stringx
+
+import "testing"
+
+func BenchmarkNodeFind(b *testing.B) {
+	b.ReportAllocs()
+
+	keywords := []string{
+		"A",
+		"AV",
+		"AV演员",
+		"无名氏",
+		"AV演员色情",
+		"日本AV女优",
+	}
+	trie := new(node)
+	for _, keyword := range keywords {
+		trie.add(keyword)
+	}
+	trie.build()
+
+	for i := 0; i < b.N; i++ {
+		trie.find([]rune("日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演"))
+	}
+}

+ 53 - 30
core/stringx/replacer.go

@@ -9,7 +9,7 @@ type (
 	}
 	}
 
 
 	replacer struct {
 	replacer struct {
-		node
+		*node
 		mapping map[string]string
 		mapping map[string]string
 	}
 	}
 )
 )
@@ -17,58 +17,81 @@ type (
 // NewReplacer returns a Replacer.
 // NewReplacer returns a Replacer.
 func NewReplacer(mapping map[string]string) Replacer {
 func NewReplacer(mapping map[string]string) Replacer {
 	rep := &replacer{
 	rep := &replacer{
+		node:    new(node),
 		mapping: mapping,
 		mapping: mapping,
 	}
 	}
 	for k := range mapping {
 	for k := range mapping {
 		rep.add(k)
 		rep.add(k)
 	}
 	}
+	rep.build()
 
 
 	return rep
 	return rep
 }
 }
 
 
+// Replace replaces text with given substitutes.
 func (r *replacer) Replace(text string) string {
 func (r *replacer) Replace(text string) string {
 	var builder strings.Builder
 	var builder strings.Builder
+	var start int
 	chars := []rune(text)
 	chars := []rune(text)
 	size := len(chars)
 	size := len(chars)
-	start := -1
 
 
-	for i := 0; i < size; i++ {
-		child, ok := r.children[chars[i]]
-		if !ok {
-			builder.WriteRune(chars[i])
-			continue
-		}
+	for start < size {
+		cur := r.node
 
 
-		if start < 0 {
-			start = i
-		}
-		end := -1
-		if child.end {
-			end = i + 1
+		if start > 0 {
+			builder.WriteString(string(chars[:start]))
 		}
 		}
 
 
-		j := i + 1
-		for ; j < size; j++ {
-			grandchild, ok := child.children[chars[j]]
-			if !ok {
-				break
+		for i := start; i < size; i++ {
+			child, ok := cur.children[chars[i]]
+			if ok {
+				cur = child
+			} else if cur == r.node {
+				builder.WriteRune(chars[i])
+				// cur already points to root, set start only
+				start = i + 1
+				continue
+			} else {
+				curDepth := cur.depth
+				cur = cur.fail
+				child, ok = cur.children[chars[i]]
+				if !ok {
+					// write this path
+					builder.WriteString(string(chars[i-curDepth : i+1]))
+					// go to root
+					cur = r.node
+					start = i + 1
+					continue
+				}
+
+				failDepth := cur.depth
+				// write path before jump
+				builder.WriteString(string(chars[start : start+curDepth-failDepth]))
+				start += curDepth - failDepth
+				cur = child
 			}
 			}
 
 
-			child = grandchild
-			if child.end {
-				end = j + 1
-				i = j
+			if cur.end {
+				val := string(chars[i+1-cur.depth : i+1])
+				builder.WriteString(r.mapping[val])
+				builder.WriteString(string(chars[i+1:]))
+				// only matching this path, all previous paths are done
+				if start >= i+1-cur.depth && i+1 >= size {
+					return builder.String()
+				}
+
+				chars = []rune(builder.String())
+				size = len(chars)
+				builder.Reset()
+				break
 			}
 			}
 		}
 		}
 
 
-		if end > 0 {
-			i = j - 1
-			builder.WriteString(r.mapping[string(chars[start:end])])
-		} else {
-			builder.WriteRune(chars[i])
+		if !cur.end {
+			builder.WriteString(string(chars[start:]))
+			return builder.String()
 		}
 		}
-		start = -1
 	}
 	}
 
 
-	return builder.String()
+	return string(chars)
 }
 }

+ 42 - 0
core/stringx/replacer_fuzz_test.go

@@ -0,0 +1,42 @@
+//go:build go1.18
+// +build go1.18
+
+package stringx
+
+import (
+	"fmt"
+	"math/rand"
+	"strings"
+	"testing"
+)
+
+func FuzzReplacerReplace(f *testing.F) {
+	keywords := make(map[string]string)
+	for i := 0; i < 20; i++ {
+		keywords[Randn(rand.Intn(10)+5)] = Randn(rand.Intn(5) + 1)
+	}
+	rep := NewReplacer(keywords)
+	printableKeywords := func() string {
+		var buf strings.Builder
+		for k, v := range keywords {
+			fmt.Fprintf(&buf, "%q: %q,\n", k, v)
+		}
+		return buf.String()
+	}
+
+	f.Add(50)
+	f.Fuzz(func(t *testing.T, n int) {
+		text := Randn(rand.Intn(n%50+50) + 1)
+		defer func() {
+			if r := recover(); r != nil {
+				t.Errorf("mapping: %s\ntext: %s", printableKeywords(), text)
+			}
+		}()
+		val := rep.Replace(text)
+		keys := rep.(*replacer).node.find([]rune(val))
+		if len(keys) > 0 {
+			t.Errorf("mapping: %s\ntext: %s\nresult: %s\nmatch: %v",
+				printableKeywords(), text, val, keys)
+		}
+	})
+}

+ 104 - 0
core/stringx/replacer_test.go

@@ -15,6 +15,14 @@ func TestReplacer_Replace(t *testing.T) {
 	assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
 	assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
 }
 }
 
 
+func TestReplacer_ReplaceOverlap(t *testing.T) {
+	mapping := map[string]string{
+		"3d": "34",
+		"bc": "23",
+	}
+	assert.Equal(t, "a234e", NewReplacer(mapping).Replace("abcde"))
+}
+
 func TestReplacer_ReplaceSingleChar(t *testing.T) {
 func TestReplacer_ReplaceSingleChar(t *testing.T) {
 	mapping := map[string]string{
 	mapping := map[string]string{
 		"二": "2",
 		"二": "2",
@@ -42,3 +50,99 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
 	}
 	}
 	assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
 	assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
 }
 }
+
+func TestReplacer_ReplaceJumpToFail(t *testing.T) {
+	mapping := map[string]string{
+		"bcdf": "1235",
+		"cde":  "234",
+	}
+	assert.Equal(t, "ab234fg", NewReplacer(mapping).Replace("abcdefg"))
+}
+
+func TestReplacer_ReplaceJumpToFailDup(t *testing.T) {
+	mapping := map[string]string{
+		"bcdf": "1235",
+		"ccde": "2234",
+	}
+	assert.Equal(t, "ab2234fg", NewReplacer(mapping).Replace("abccdefg"))
+}
+
+func TestReplacer_ReplaceJumpToFailEnding(t *testing.T) {
+	mapping := map[string]string{
+		"bcdf": "1235",
+		"cdef": "2345",
+	}
+	assert.Equal(t, "ab2345", NewReplacer(mapping).Replace("abcdef"))
+}
+
+func TestReplacer_ReplaceEmpty(t *testing.T) {
+	mapping := map[string]string{
+		"bcdf": "1235",
+		"cdef": "2345",
+	}
+	assert.Equal(t, "", NewReplacer(mapping).Replace(""))
+}
+
+func TestFuzzCase1(t *testing.T) {
+	keywords := map[string]string{
+		"yQyJykiqoh":     "xw",
+		"tgN70z":         "Q2P",
+		"tXKhEn":         "w1G8",
+		"5nfOW1XZO":      "GN",
+		"f4Ov9i9nHD":     "cT",
+		"1ov9Q":          "Y",
+		"7IrC9n":         "400i",
+		"JQLxonpHkOjv":   "XI",
+		"DyHQ3c7":        "Ygxux",
+		"ffyqJi":         "u",
+		"UHuvXrbD8pni":   "dN",
+		"LIDzNbUlTX":     "g",
+		"yN9WZh2rkc8Q":   "3U",
+		"Vhk11rz8CObceC": "jf",
+		"R0Rt4H2qChUQf":  "7U5M",
+		"MGQzzPCVKjV9":   "yYz",
+		"B5jUUl0u1XOY":   "l4PZ",
+		"pdvp2qfLgG8X":   "BM562",
+		"ZKl9qdApXJ2":    "T",
+		"37jnugkSevU66":  "aOHFX",
+	}
+	rep := NewReplacer(keywords)
+	text := "yjF8fyqJiiqrczOCVyoYbLvrMpnkj"
+	val := rep.Replace(text)
+	keys := rep.(*replacer).node.find([]rune(val))
+	if len(keys) > 0 {
+		t.Errorf("result: %s, match: %v", val, keys)
+	}
+}
+
+func TestFuzzCase2(t *testing.T) {
+	keywords := map[string]string{
+		"dmv2SGZvq9Yz":   "TE",
+		"rCL5DRI9uFP8":   "hvsc8",
+		"7pSA2jaomgg":    "v",
+		"kWSQvjVOIAxR":   "Oje",
+		"hgU5bYYkD3r6":   "qCXu",
+		"0eh6uI":         "MMlt",
+		"3USZSl85EKeMzw": "Pc",
+		"JONmQSuXa":      "dX",
+		"EO1WIF":         "G",
+		"uUmFJGVmacjF":   "1N",
+		"DHpw7":          "M",
+		"NYB2bm":         "CPya",
+		"9FiNvBAHHNku5":  "7FlDE",
+		"tJi3I4WxcY":     "q5",
+		"sNJ8Z1ToBV0O":   "tl",
+		"0iOg72QcPo":     "RP",
+		"pSEqeL":         "5KZ",
+		"GOyYqTgmvQ":     "9",
+		"Qv4qCsj":        "nl52E",
+		"wNQ5tOutYu5s8":  "6iGa",
+	}
+	rep := NewReplacer(keywords)
+	text := "AoRxrdKWsGhFpXwVqMLWRL74OukwjBuBh0g7pSrk"
+	val := rep.Replace(text)
+	keys := rep.(*replacer).node.find([]rune(val))
+	if len(keys) > 0 {
+		t.Errorf("result: %s, match: %v", val, keys)
+	}
+}

+ 4 - 44
core/stringx/trie.go

@@ -39,6 +39,8 @@ func NewTrie(words []string, opts ...TrieOption) Trie {
 		n.add(word)
 		n.add(word)
 	}
 	}
 
 
+	n.build()
+
 	return n
 	return n
 }
 }
 
 
@@ -48,7 +50,7 @@ func (n *trieNode) Filter(text string) (sentence string, keywords []string, foun
 		return text, nil, false
 		return text, nil, false
 	}
 	}
 
 
-	scopes := n.findKeywordScopes(chars)
+	scopes := n.find(chars)
 	keywords = n.collectKeywords(chars, scopes)
 	keywords = n.collectKeywords(chars, scopes)
 
 
 	for _, match := range scopes {
 	for _, match := range scopes {
@@ -65,7 +67,7 @@ func (n *trieNode) FindKeywords(text string) []string {
 		return nil
 		return nil
 	}
 	}
 
 
-	scopes := n.findKeywordScopes(chars)
+	scopes := n.find(chars)
 	return n.collectKeywords(chars, scopes)
 	return n.collectKeywords(chars, scopes)
 }
 }
 
 
@@ -85,48 +87,6 @@ func (n *trieNode) collectKeywords(chars []rune, scopes []scope) []string {
 	return keywords
 	return keywords
 }
 }
 
 
-func (n *trieNode) findKeywordScopes(chars []rune) []scope {
-	var scopes []scope
-	size := len(chars)
-	start := -1
-
-	for i := 0; i < size; i++ {
-		child, ok := n.children[chars[i]]
-		if !ok {
-			continue
-		}
-
-		if start < 0 {
-			start = i
-		}
-		if child.end {
-			scopes = append(scopes, scope{
-				start: start,
-				stop:  i + 1,
-			})
-		}
-
-		for j := i + 1; j < size; j++ {
-			grandchild, ok := child.children[chars[j]]
-			if !ok {
-				break
-			}
-
-			child = grandchild
-			if child.end {
-				scopes = append(scopes, scope{
-					start: start,
-					stop:  j + 1,
-				})
-			}
-		}
-
-		start = -1
-	}
-
-	return scopes
-}
-
 func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) {
 func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) {
 	for i := start; i < stop; i++ {
 	for i := start; i < stop; i++ {
 		chars[i] = n.mask
 		chars[i] = n.mask

+ 14 - 20
core/stringx/trie_test.go

@@ -6,6 +6,17 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
+func TestTrieSimple(t *testing.T) {
+	trie := NewTrie([]string{
+		"bc",
+		"cd",
+	})
+	output, keywords, found := trie.Filter("abcd")
+	assert.True(t, found)
+	assert.Equal(t, "a***", output)
+	assert.ElementsMatch(t, []string{"bc", "cd"}, keywords)
+}
+
 func TestTrie(t *testing.T) {
 func TestTrie(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		input    string
 		input    string
@@ -14,11 +25,11 @@ func TestTrie(t *testing.T) {
 		found    bool
 		found    bool
 	}{
 	}{
 		{
 		{
-			input:  "日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演",
+			input:  "日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演",
 			output: "日本****兼电视、电影演员。*****女优是xx出道, ******们最精彩的表演是******表演",
 			output: "日本****兼电视、电影演员。*****女优是xx出道, ******们最精彩的表演是******表演",
 			keywords: []string{
 			keywords: []string{
 				"AV演员",
 				"AV演员",
-				"苍井空",
+				"无名氏",
 				"AV",
 				"AV",
 				"日本AV女优",
 				"日本AV女优",
 				"AV演员色情",
 				"AV演员色情",
@@ -89,7 +100,7 @@ func TestTrie(t *testing.T) {
 		"一不",
 		"一不",
 		"AV",
 		"AV",
 		"AV演员",
 		"AV演员",
-		"苍井空",
+		"无名氏",
 		"AV演员色情",
 		"AV演员色情",
 		"日本AV女优",
 		"日本AV女优",
 	})
 	})
@@ -145,20 +156,3 @@ func TestTrieNested(t *testing.T) {
 	assert.True(t, ok)
 	assert.True(t, ok)
 	assert.Equal(t, "零########九十", output)
 	assert.Equal(t, "零########九十", output)
 }
 }
-
-func BenchmarkTrie(b *testing.B) {
-	b.ReportAllocs()
-
-	trie := NewTrie([]string{
-		"A",
-		"AV",
-		"AV演员",
-		"苍井空",
-		"AV演员色情",
-		"日本AV女优",
-	})
-
-	for i := 0; i < b.N; i++ {
-		trie.Filter("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演")
-	}
-}

+ 0 - 10
core/syncx/singleflight.go

@@ -3,10 +3,6 @@ package syncx
 import "sync"
 import "sync"
 
 
 type (
 type (
-	// SharedCalls is an alias of SingleFlight.
-	// Deprecated: use SingleFlight.
-	SharedCalls = SingleFlight
-
 	// SingleFlight lets the concurrent calls with the same key to share the call result.
 	// SingleFlight lets the concurrent calls with the same key to share the call result.
 	// For example, A called F, before it's done, B called F. Then B would not execute F,
 	// For example, A called F, before it's done, B called F. Then B would not execute F,
 	// and shared the result returned by F which called by A.
 	// and shared the result returned by F which called by A.
@@ -37,12 +33,6 @@ func NewSingleFlight() SingleFlight {
 	}
 	}
 }
 }
 
 
-// NewSharedCalls returns a SingleFlight.
-// Deprecated: use NewSingleFlight.
-func NewSharedCalls() SingleFlight {
-	return NewSingleFlight()
-}
-
 func (g *flightGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
 func (g *flightGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
 	c, done := g.createCall(key)
 	c, done := g.createCall(key)
 	if done {
 	if done {