trie.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package stringx
  2. import "github.com/tal-tech/go-zero/core/lang"
  3. type (
  4. Trie interface {
  5. Filter(text string) (string, []string, bool)
  6. FindKeywords(text string) []string
  7. }
  8. trieNode struct {
  9. node
  10. }
  11. scope struct {
  12. start int
  13. stop int
  14. }
  15. )
  16. func NewTrie(words []string) Trie {
  17. n := new(trieNode)
  18. for _, word := range words {
  19. n.add(word)
  20. }
  21. return n
  22. }
  23. func (n *trieNode) Filter(text string) (sentence string, keywords []string, found bool) {
  24. chars := []rune(text)
  25. if len(chars) == 0 {
  26. return text, nil, false
  27. }
  28. scopes := n.findKeywordScopes(chars)
  29. keywords = n.collectKeywords(chars, scopes)
  30. for _, match := range scopes {
  31. // we don't care about overlaps, not bringing a performance improvement
  32. n.replaceWithAsterisk(chars, match.start, match.stop)
  33. }
  34. return string(chars), keywords, len(keywords) > 0
  35. }
  36. func (n *trieNode) FindKeywords(text string) []string {
  37. chars := []rune(text)
  38. if len(chars) == 0 {
  39. return nil
  40. }
  41. scopes := n.findKeywordScopes(chars)
  42. return n.collectKeywords(chars, scopes)
  43. }
  44. func (n *trieNode) collectKeywords(chars []rune, scopes []scope) []string {
  45. set := make(map[string]lang.PlaceholderType)
  46. for _, v := range scopes {
  47. set[string(chars[v.start:v.stop])] = lang.Placeholder
  48. }
  49. var i int
  50. keywords := make([]string, len(set))
  51. for k := range set {
  52. keywords[i] = k
  53. i++
  54. }
  55. return keywords
  56. }
  57. func (n *trieNode) findKeywordScopes(chars []rune) []scope {
  58. var scopes []scope
  59. size := len(chars)
  60. start := -1
  61. for i := 0; i < size; i++ {
  62. child, ok := n.children[chars[i]]
  63. if !ok {
  64. continue
  65. }
  66. if start < 0 {
  67. start = i
  68. }
  69. if child.end {
  70. scopes = append(scopes, scope{
  71. start: start,
  72. stop: i + 1,
  73. })
  74. }
  75. for j := i + 1; j < size; j++ {
  76. grandchild, ok := child.children[chars[j]]
  77. if !ok {
  78. break
  79. }
  80. child = grandchild
  81. if child.end {
  82. scopes = append(scopes, scope{
  83. start: start,
  84. stop: j + 1,
  85. })
  86. }
  87. }
  88. start = -1
  89. }
  90. return scopes
  91. }
  92. func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) {
  93. for i := start; i < stop; i++ {
  94. chars[i] = '*'
  95. }
  96. }