-
Notifications
You must be signed in to change notification settings - Fork 28.4k
/
Copy pathuse_own_knowledge_dataset.py
199 lines (162 loc) · 7.66 KB
/
use_own_knowledge_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional
import torch
from datasets import load_dataset
import faiss
from transformers import (
DPRContextEncoder,
DPRContextEncoderTokenizerFast,
HfArgumentParser,
RagRetriever,
RagSequenceForGeneration,
RagTokenizer,
)
# Module-level logger used for the step-by-step progress messages in this file.
logger = logging.getLogger(__name__)

# Nothing in this file trains a model (it only embeds passages and generates
# an answer), so gradient tracking is switched off globally.
torch.set_grad_enabled(False)

# Run the DPR context encoder on the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split ``text`` at every ``n``-th occurrence of ``character``.

    Args:
        text: the string to cut into passages.
        n: maximum number of ``character``-separated pieces per passage.
        character: the separator to split on; it is also used to re-join the
            pieces inside each passage.

    Returns:
        The passages in their original order, each stripped of leading and
        trailing whitespace. An empty ``text`` yields ``[""]``.
    """
    pieces = text.split(character)
    passages = []
    for start in range(0, len(pieces), n):
        chunk = character.join(pieces[start : start + n])
        passages.append(chunk.strip())
    return passages
def split_documents(documents: dict) -> dict:
    """Split a batch of documents into passages.

    Args:
        documents: a batch with parallel ``"title"`` and ``"text"`` lists
            (one entry per document).

    Returns:
        A dict with ``"title"`` and ``"text"`` lists holding one entry per
        passage; every passage carries the title of the document it came
        from. Documents without text produce no passages.
    """
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        # Empty csv cells can reach this function as None. A row without text
        # has nothing to split, and calling split_text on None would raise.
        if text is None:
            continue
        # A missing title is replaced by "" so that downstream consumers
        # (e.g. the tokenizer call in `embed`) always receive strings.
        passage_title = title if title is not None else ""
        for passage in split_text(text):
            titles.append(passage_title)
            texts.append(passage)
    return {"title": titles, "text": texts}
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of a batch of document passages.

    Args:
        documents: a batch with parallel ``"title"`` and ``"text"`` lists.
        ctx_encoder: the DPR context encoder that produces the embeddings.
        ctx_tokenizer: the tokenizer matching ``ctx_encoder``.

    Returns:
        A dict whose ``"embeddings"`` entry is the encoder's ``pooler_output``
        for the batch, moved to the CPU and converted to a numpy array.
    """
    # Titles and texts are encoded as pairs, truncated where needed and padded
    # to the longest pair in the batch.
    encoded = ctx_tokenizer(
        documents["title"],
        documents["text"],
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"].to(device=device)
    encoder_output = ctx_encoder(token_ids, return_dict=True)
    pooled = encoder_output.pooler_output
    return {"embeddings": pooled.detach().cpu().numpy()}
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """Build a custom knowledge dataset, index it, and query RAG with it.

    The function runs four steps: it turns a tab-separated csv file into a
    dataset of embedded passages, builds and saves a Faiss HNSW index over the
    embeddings, loads a RAG model on top of that dataset, and finally logs the
    answer generated for one question.

    Args:
        rag_example_args: csv path, question, model names and output directory.
        processing_args: number of processes and batch size for preprocessing.
        index_hnsw_args: dimension and link count of the HNSW index.

    Raises:
        AssertionError: if ``rag_example_args.csv_path`` is not an existing file.
    """
    # ---------------------------------------------------------------- Step 1
    logger.info("Step 1 - Create the dataset")

    # RAG expects a dataset with three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # The input is assumed to be a tab-separated csv file with the columns
    # "title" and "text".
    csv_path = rag_example_args.csv_path
    assert os.path.isfile(csv_path), "Please provide a valid path to a csv file"

    # Load the file as a Dataset object. More info about loading csv files:
    # https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
    documents = load_dataset(
        "csv",
        data_files=[csv_path],
        split="train",
        delimiter="\t",
        column_names=["title", "text"],
    )

    # Cut every document into passages of 100 words.
    passages = documents.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # Embed every passage with the DPR context encoder.
    encoder_name = rag_example_args.dpr_ctx_encoder_model_name
    ctx_encoder = DPRContextEncoder.from_pretrained(encoder_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(encoder_name)
    embed_batch = partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer)
    knowledge = passages.map(embed_batch, batched=True, batch_size=processing_args.batch_size)

    # Persist the passages. They can be reloaded later with
    # `datasets.load_from_disk(passages_path)`.
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    knowledge.save_to_disk(passages_path)

    # ---------------------------------------------------------------- Step 2
    logger.info("Step 2 - Index the dataset")

    # Use the Faiss implementation of HNSW for fast approximate nearest
    # neighbor search.
    hnsw_index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    knowledge.add_faiss_index("embeddings", custom_index=hnsw_index)

    # Persist the index. It can be reloaded later with
    # `knowledge.load_faiss_index("embeddings", index_path)`.
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    knowledge.get_index("embeddings").save(index_path)

    # ---------------------------------------------------------------- Step 3
    logger.info("Step 3 - Load RAG")

    # Easy way to load the model: hand the in-memory indexed dataset to the
    # retriever.
    rag_name = rag_example_args.rag_model_name
    retriever = RagRetriever.from_pretrained(rag_name, index_name="custom", indexed_dataset=knowledge)
    model = RagSequenceForGeneration.from_pretrained(rag_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_name)

    # For distributed fine-tuning, provide the paths instead, as the dataset
    # and the index are loaded separately:
    # RagRetriever.from_pretrained(rag_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    # ---------------------------------------------------------------- Step 4
    logger.info("Step 4 - Have fun")

    question = rag_example_args.question
    if not question:
        question = "What does Moses' rod turn into ?"
    encoded_question = tokenizer.question_encoder(question, return_tensors="pt")
    generated = model.generate(encoded_question["input_ids"])
    answers = tokenizer.batch_decode(generated, skip_special_tokens=True)
    answer = answers[0]
    logger.info("Q: " + question)
    logger.info("A: " + answer)
@dataclass
class RagExampleArguments:
    """Options describing the input data, the models and the output location."""

    # Input knowledge file; defaults to test_data/my_knowledge_dataset.csv
    # located next to this file.
    csv_path: str = field(
        metadata=dict(help="Path to a tab-separated csv file with columns 'title' and 'text'"),
        default=str(Path(__file__).parent.joinpath("test_data", "my_knowledge_dataset.csv")),
    )
    # Question asked to the model; None means "use the built-in example question".
    question: Optional[str] = field(
        metadata=dict(
            help=(
                "Question that is passed as input to RAG. "
                "Default is 'What does Moses' rod turn into ?'."
            )
        ),
        default=None,
    )
    # Name of the pretrained RAG checkpoint.
    rag_model_name: str = field(
        metadata=dict(
            help=(
                "The RAG model to use. "
                "Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"
            )
        ),
        default="facebook/rag-sequence-nq",
    )
    # Name of the pretrained DPR context encoder used to embed the passages.
    dpr_ctx_encoder_model_name: str = field(
        metadata=dict(
            help=(
                "The DPR context encoder model to use. "
                "Either 'facebook/dpr-ctx_encoder-single-nq-base' "
                "or 'facebook/dpr-ctx_encoder-multiset-base'"
            )
        ),
        default="facebook/dpr-ctx_encoder-multiset-base",
    )
    # Where the passages and the index are written; None is replaced by a
    # temporary directory when this file is run as a script.
    output_dir: Optional[str] = field(
        metadata=dict(help="Path to a directory where the dataset passages and the index will be saved"),
        default=None,
    )
@dataclass
class ProcessingArguments:
    """Options controlling how the documents are split and embedded."""

    # Worker processes for the passage-splitting step; None means one process.
    num_proc: Optional[int] = field(
        metadata=dict(
            help=(
                "The number of processes to use to split the documents into passages. "
                "Default is single process."
            )
        ),
        default=None,
    )
    # Number of passages sent to the DPR context encoder at once.
    batch_size: int = field(
        metadata=dict(
            help=(
                "The batch size to use when computing the passages embeddings "
                "using the DPR context encoder."
            )
        ),
        default=16,
    )
@dataclass
class IndexHnswArguments:
    """Options passed to the Faiss HNSW index constructor."""

    # Embedding dimension handed to the index.
    d: int = field(
        metadata=dict(help="The dimension of the embeddings to pass to the HNSW Faiss index."),
        default=768,
    )
    # Link count handed to the index.
    m: int = field(
        metadata=dict(
            help=(
                "The number of bi-directional links created for every new element "
                "during the HNSW index construction."
            )
        ),
        default=128,
    )
if __name__ == "__main__":
    # Root logging stays at WARNING; only this module's logger is raised to
    # INFO so that its step messages are shown.
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)

    argument_groups = (RagExampleArguments, ProcessingArguments, IndexHnswArguments)
    parser = HfArgumentParser(argument_groups)
    rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()

    # The temporary directory is only used as a fallback when no output
    # directory was given on the command line.
    with TemporaryDirectory() as tmp_dir:
        if not rag_example_args.output_dir:
            rag_example_args.output_dir = tmp_dir
        main(rag_example_args, processing_args, index_hnsw_args)