33
33
34
34
35
35
class BM25Retriever (BaseRetriever ):
36
- """A BM25 retriever that uses the BM25 algorithm to retrieve nodes.
36
+ r """A BM25 retriever that uses the BM25 algorithm to retrieve nodes.
37
37
38
38
Args:
39
39
nodes (List[BaseNode], optional):
@@ -52,6 +52,10 @@ class BM25Retriever(BaseRetriever):
52
52
The objects to retrieve. Defaults to None.
53
53
object_map (dict, optional):
54
54
A map of object IDs to nodes. Defaults to None.
55
+ token_pattern (str, optional):
56
+ The token pattern to use. Defaults to (?u)\b\w\w+\b.
57
+ skip_stemming (bool, optional):
58
+ Whether to skip stemming. Defaults to False.
55
59
verbose (bool, optional):
56
60
Whether to show progress. Defaults to False.
57
61
"""
@@ -67,9 +71,13 @@ def __init__(
67
71
objects : Optional [List [IndexNode ]] = None ,
68
72
object_map : Optional [dict ] = None ,
69
73
verbose : bool = False ,
74
+ skip_stemming : bool = False ,
75
+ token_pattern : str = r"(?u)\b\w\w+\b" ,
70
76
) -> None :
71
77
self .stemmer = stemmer or Stemmer .Stemmer ("english" )
72
78
self .similarity_top_k = similarity_top_k
79
+ self .token_pattern = token_pattern
80
+ self .skip_stemming = skip_stemming
73
81
74
82
if existing_bm25 is not None :
75
83
self .bm25 = existing_bm25
@@ -83,7 +91,8 @@ def __init__(
83
91
corpus_tokens = bm25s .tokenize (
84
92
[node .get_content (metadata_mode = MetadataMode .EMBED ) for node in nodes ],
85
93
stopwords = language ,
86
- stemmer = self .stemmer ,
94
+ stemmer = self .stemmer if not skip_stemming else None ,
95
+ token_pattern = self .token_pattern ,
87
96
show_progress = verbose ,
88
97
)
89
98
self .bm25 = bm25s .BM25 ()
@@ -105,6 +114,8 @@ def from_defaults(
105
114
language : str = "en" ,
106
115
similarity_top_k : int = DEFAULT_SIMILARITY_TOP_K ,
107
116
verbose : bool = False ,
117
+ skip_stemming : bool = False ,
118
+ token_pattern : str = r"(?u)\b\w\w+\b" ,
108
119
# deprecated
109
120
tokenizer : Optional [Callable [[str ], List [str ]]] = None ,
110
121
) -> "BM25Retriever" :
@@ -134,6 +145,8 @@ def from_defaults(
134
145
language = language ,
135
146
similarity_top_k = similarity_top_k ,
136
147
verbose = verbose ,
148
+ skip_stemming = skip_stemming ,
149
+ token_pattern = token_pattern ,
137
150
)
138
151
139
152
def get_persist_args (self ) -> Dict [str , Any ]:
@@ -161,7 +174,10 @@ def from_persist_dir(cls, path: str, **kwargs: Any) -> "BM25Retriever":
161
174
def _retrieve (self , query_bundle : QueryBundle ) -> List [NodeWithScore ]:
162
175
query = query_bundle .query_str
163
176
tokenized_query = bm25s .tokenize (
164
- query , stemmer = self .stemmer , show_progress = self ._verbose
177
+ query ,
178
+ stemmer = self .stemmer if not self .skip_stemming else None ,
179
+ token_pattern = self .token_pattern ,
180
+ show_progress = self ._verbose ,
165
181
)
166
182
indexes , scores = self .bm25 .retrieve (
167
183
tokenized_query , k = self .similarity_top_k , show_progress = self ._verbose
0 commit comments