-
Notifications
You must be signed in to change notification settings - Fork 28.5k
/
Copy pathtest_retrieval_rag.py
387 lines (351 loc) · 17.2 KB
/
test_retrieval_rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch
import numpy as np
from datasets import Dataset
from transformers import is_faiss_available
from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.bart.tokenization_bart import BartTokenizer
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.configuration_dpr import DPRConfig
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from transformers.models.rag.configuration_rag import RagConfig
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import (
require_datasets,
require_faiss,
require_sentencepiece,
require_tokenizers,
require_torch,
)
if is_faiss_available():
import faiss
@require_faiss
@require_datasets
class RagRetrieverTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
return DPRContextEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_dummy_dataset(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
return dataset
def get_dummy_canonical_hf_index_retriever(self):
dataset = self.get_dummy_dataset()
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
)
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = dataset
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
)
return retriever
def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
dataset = self.get_dummy_dataset()
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="custom",
)
if from_disk:
config.passages_path = os.path.join(self.tmpdirname, "dataset")
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
dataset.drop_index("embeddings")
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
del dataset
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
)
else:
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
index=CustomHFIndex(config.retrieval_vector_size, dataset),
)
return retriever
def get_dummy_legacy_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
pickle.dump(passages, open(passages_file_name, "wb"))
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="legacy",
index_path=self.tmpdirname,
)
retriever = RagRetriever(
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
)
return retriever
def test_canonical_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_canonical_hf_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_canonical_hf_index_retriever()
with tempfile.TemporaryDirectory() as tmp_dirname:
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = self.get_dummy_dataset()
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
def test_custom_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_custom_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
def test_custom_hf_index_retriever_retrieve_from_disk(self):
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
def test_legacy_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_legacy_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_legacy_index_retriever()
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
@require_torch
@require_tokenizers
@require_sentencepiece
def test_hf_index_retriever_call(self):
import torch
n_docs = 1
retriever = self.get_dummy_canonical_hf_index_retriever()
question_input_ids = [[5, 7], [10, 11]]
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, list)
self.assertIsInstance(context_attention_mask, list)
self.assertIsInstance(retrieved_doc_embeds, np.ndarray)
out = retriever(
question_input_ids,
hidden_states,
prefix=retriever.config.generator.prefix,
n_docs=n_docs,
return_tensors="pt",
)
context_input_ids, context_attention_mask, retrieved_doc_embeds, doc_ids = ( # noqa: F841
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
out["doc_ids"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, torch.Tensor)
self.assertIsInstance(context_attention_mask, torch.Tensor)
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
@require_torch
@require_tokenizers
@require_sentencepiece
def test_custom_hf_index_end2end_retriever_call(self):
context_encoder_tokenizer = self.get_dpr_ctx_encoder_tokenizer()
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer)
question_input_ids = [[5, 7], [10, 11]]
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
self.assertEqual(
len(out), 6
) # check whether the retriever output consist of 6 attributes including tokenized docs
self.assertEqual(
all(k in out for k in ("tokenized_doc_ids", "tokenized_doc_attention_mask")), True
) # check for doc token related keys in dictionary.