Commit c754c41

ola13, patrick-s-h-lewis, Aleksandra Piktus authored
RAG (#6813)
* added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * Formatting / renaming prior to actual work * First commit * improve comments * Retrieval evaluation scripts * refactor to include modeling outputs + MPI retriever * Fix rag-token model + refactor * Various fixes + finetuning logic * use_bos fix * Retrieval refactor * Finetuning refactoring and cleanup * Add documentation and cleanup * Remove set_up_rag_env.sh file * Fix retrieval wit HF index * Fix import errors * Fix quality errors * Refactor as per suggestions in #6813 (comment) * fix quality * Fix RAG Sequence generation * minor cleanup plus initial tests * fix test * fix tests 2 * Comments fix * post-merge fixes * Improve readme + post-rebase refactor * Extra dependencied for tests * Fix tests * Fix tests 2 * Refactor test requirements * Fix tests 3 * Post-rebase refactor * rename nlp->datasets * RAG integration tests * add tokenizer to slow integration test and allow retriever to run on cpu * add tests; fix position ids warning * change structure * change structure * add from encoder generator * save working solution * make all integration tests pass * add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained * don't save paths * delete unnecessary imports * pass config to AutoTokenizer.from_pretrained for Rag tokenizers * init wiki_dpr only once * hardcode legacy index and passages paths (todo: add the right urls) * finalize config * finalize retriver api and config api * LegacyIndex index download refactor * add dpr to autotokenizer * make from pretrained more flexible * fix ragfortokengeneration * small name changes in tokenizer * add labels to models * change default index name * add retrieval tests * finish token generate * align test with previous version and make all tests pass * add tests * finalize tests 
* implement thoms suggestions * add first version of test * make first tests work * make retriever platform agnostic * naming * style * add legacy index URL * docstrings + simple retrieval test for distributed * clean model api * add doc_ids to retriever's outputs * fix retrieval tests * finish model outputs * finalize model api * fix generate problem for rag * fix generate for other modles * fix some tests * save intermediate * set generate to default * big refactor generate * delete rag_api * correct pip faiss install * fix auto tokenization test * fix faiss install * fix test * move the distributed logic to examples * model page * docs * finish tests * fix dependencies * fix import in __init__ * Refactor eval_rag and finetune scripts * start docstring * add psutil to test * fix tf test * move require torch to top * fix retrieval test * align naming * finish automodel * fix repo consistency * test ragtokenizer save/load * add rag model output docs * fix ragtokenizer save/load from pretrained * fix tokenizer dir * remove torch in retrieval * fix docs * fixe finetune scripts * finish model docs * finish docs * remove auto model for now * add require torch * remove solved todos * integrate sylvains suggestions * sams comments * correct mistake on purpose * improve README * Add generation test cases * fix rag token * clean token generate * fix test * add note to test * fix attention mask * add t5 test for rag * Fix handling prefix in finetune.py * don't overwrite index_name Co-authored-by: Patrick Lewis <plewis@fb.com> Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair> Co-authored-by: Your Name <you@example.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
1 parent 1ee2194 commit c754c41

37 files changed: +5176 -32 lines changed

.gitignore

+2
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ __pycache__/
1111
# tests and logs
1212
tests/fixtures
1313
logs/
14+
lightning_logs/
1415

1516
# Distribution / packaging
1617
.Python
@@ -139,6 +140,7 @@ runs
139140
/wandb
140141
/examples/runs
141142
/examples/**/*.args
143+
/examples/rag/sweep
142144

143145
# data
144146
/data

docs/source/index.rst

+1
@@ -231,6 +231,7 @@ conversion utilities for the following models:
231231
model_doc/lxmert
232232
model_doc/bertgeneration
233233
model_doc/layoutlm
234+
model_doc/rag
234235
internal/modeling_utils
235236
internal/tokenization_utils
236237
internal/pipelines_utils

docs/source/model_doc/rag.rst

+88
@@ -0,0 +1,88 @@
1+
RAG
2+
----------------------------------------------------
3+
4+
Overview
5+
~~~~~~~~~~~~~~~~~~~~~
6+
7+
Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models.
8+
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs.
9+
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks.
10+
11+
It is based on the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`__ by Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
12+
13+
The abstract from the paper is the following:
14+
15+
*Large pre-trained language models have been shown to store factual knowledge
16+
in their parameters, and achieve state-of-the-art results when fine-tuned on
17+
downstream NLP tasks. However, their ability to access and precisely manipulate
18+
knowledge is still limited, and hence on knowledge-intensive tasks, their
19+
performance lags behind task-specific architectures. Additionally, providing
20+
provenance for their decisions and updating their world knowledge remain open
21+
research problems. Pre-trained models with a differentiable access mechanism to
22+
explicit nonparametric memory can overcome this issue, but have so far been only
23+
investigated for extractive downstream tasks. We explore a general-purpose
24+
fine-tuning recipe for retrieval-augmented generation (RAG) — models which combine
25+
pre-trained parametric and non-parametric memory for language generation. We
26+
introduce RAG models where the parametric memory is a pre-trained seq2seq model and
27+
the non-parametric memory is a dense vector index of Wikipedia, accessed with
28+
a pre-trained neural retriever. We compare two RAG formulations, one which
29+
conditions on the same retrieved passages across the whole generated sequence, the
30+
other can use different passages per token. We fine-tune and evaluate our models
31+
on a wide range of knowledge-intensive NLP tasks and set the state-of-the-art
32+
on three open domain QA tasks, outperforming parametric seq2seq models and
33+
task-specific retrieve-and-extract architectures. For language generation tasks, we
34+
find that RAG models generate more specific, diverse and factual language than a
35+
state-of-the-art parametric-only seq2seq baseline.*
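
To make the two formulations mentioned in the abstract concrete, here is a minimal toy sketch of how each one marginalizes over retrieved passages. The probabilities are made-up illustrative numbers, and the function names `rag_sequence_prob` and `rag_token_prob` are hypothetical helpers, not part of the library.

```python
from math import prod


def rag_sequence_prob(doc_probs, token_probs):
    """RAG-Sequence: p(y|x) = sum_z p(z|x) * prod_i p(y_i | x, z, y_<i)."""
    return sum(p_z * prod(tokens) for p_z, tokens in zip(doc_probs, token_probs))


def rag_token_prob(doc_probs, token_probs):
    """RAG-Token: p(y|x) = prod_i sum_z p(z|x) * p(y_i | x, z, y_<i)."""
    n_tokens = len(token_probs[0])
    return prod(
        sum(p_z * tokens[i] for p_z, tokens in zip(doc_probs, token_probs))
        for i in range(n_tokens)
    )


# Two retrieved passages with made-up retrieval probabilities p(z|x).
doc_probs = [0.5, 0.5]
# token_probs[z][i]: made-up probability of the i-th target token given passage z.
# Passage 0 only supports the first token, passage 1 only the second.
token_probs = [[1.0, 0.0], [0.0, 1.0]]

seq_p = rag_sequence_prob(doc_probs, token_probs)
tok_p = rag_token_prob(doc_probs, token_probs)
print(seq_p, tok_p)
```

In this toy case no single passage explains the whole sequence, so the sequence-level marginal is zero, while the token-level marginal stays positive because each token can draw on a different passage.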
36+
37+
38+
39+
RagConfig
40+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
41+
42+
.. autoclass:: transformers.RagConfig
43+
:members:
44+
45+
46+
RagTokenizer
47+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
48+
49+
.. autoclass:: transformers.RagTokenizer
50+
:members:
51+
52+
53+
Rag specific outputs
54+
~~~~~~~~~~~~~~~~~~~~~
55+
56+
.. autoclass:: transformers.modeling_rag.RetrievAugLMMarginOutput
57+
:members:
58+
59+
.. autoclass:: transformers.modeling_rag.RetrievAugLMOutput
60+
:members:
61+
62+
63+
RagRetriever
64+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
65+
66+
.. autoclass:: transformers.RagRetriever
67+
:members:
68+
69+
70+
RagModel
71+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72+
73+
.. autoclass:: transformers.RagModel
74+
:members: forward
75+
76+
77+
RagSequenceForGeneration
78+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
79+
80+
.. autoclass:: transformers.RagSequenceForGeneration
81+
:members: forward, generate
82+
83+
84+
RagTokenForGeneration
85+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
86+
87+
.. autoclass:: transformers.RagTokenForGeneration
88+
:members: forward, generate

docs/source/model_summary.rst

+22-1
@@ -654,7 +654,7 @@ DPR
654654
<a href="https://huggingface.co/models?filter=dpr">
655655
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-dpr-blueviolet">
656656
</a>
657-
<a href="model_doc/ctrl.dpr">
657+
<a href="model_doc/dpr.html">
658658
<img alt="Doc" src="https://img.shields.io/badge/Model_documentation-dpr-blueviolet">
659659
</a>
660660

@@ -672,6 +672,27 @@ DPR consists in three models:
672672

673673
DPR's pipeline (not implemented yet) uses a retrieval step to find the top k contexts given a certain question, and then it calls the reader with the question and the retrieved documents to get the answer.
674674

675+
RAG
676+
----------------------------------------------
677+
678+
.. raw:: html
679+
680+
<a href="https://huggingface.co/models?filter=rag">
681+
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-rag-blueviolet">
682+
</a>
683+
<a href="model_doc/rag.html">
684+
<img alt="Doc" src="https://img.shields.io/badge/Model_documentation-rag-blueviolet">
685+
</a>
686+
687+
`Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`_,
688+
Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela
689+
690+
Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models.
691+
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs.
692+
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks.
693+
694+
The two models RAG-Token and RAG-Sequence are available for generation.
695+
675696
More technical aspects
676697
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
677698

examples/lightning_base.py

+2
@@ -366,6 +366,8 @@ def generic_train(
366366
if args.gpus > 1:
367367
train_params["distributed_backend"] = "ddp"
368368

369+
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
370+
369371
trainer = pl.Trainer.from_argparse_args(
370372
args,
371373
weights_summary=None,
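
As a rough illustration of what forwarding `accumulate_grad_batches` to the trainer changes, the sketch below rebuilds the extra trainer parameters from parsed arguments and estimates the resulting effective batch size. The helper names and the `train_batch_size` argument are assumptions made for this example; only `gpus` and `accumulate_grad_batches` appear in the hunk above.

```python
from argparse import Namespace


def build_train_params(args):
    """Collect extra Trainer keyword arguments (hypothetical helper mirroring the hunk above)."""
    train_params = {}
    if args.gpus > 1:
        train_params["distributed_backend"] = "ddp"
    train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
    return train_params


def effective_batch_size(args):
    """Examples contributing to one optimizer step, assuming one process per GPU."""
    return args.train_batch_size * max(args.gpus, 1) * args.accumulate_grad_batches


# `train_batch_size` is an assumed per-GPU batch size argument.
args = Namespace(gpus=8, accumulate_grad_batches=4, train_batch_size=2)
params = build_train_params(args)
print(params, effective_batch_size(args))
```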

examples/longform-qa/eli5_app.py

+1-1
@@ -1,10 +1,10 @@
11
import datasets
2-
import faiss
32
import numpy as np
43
import streamlit as st
54
import torch
65
from elasticsearch import Elasticsearch
76

7+
import faiss
88
import transformers
99
from eli5_utils import (
1010
embed_questions_for_retrieval,

examples/longform-qa/eli5_utils.py

+1-1
@@ -5,7 +5,6 @@
55
from time import time
66

77
import datasets # noqa: F401
8-
import faiss # noqa: F401
98
import numpy as np
109
import pandas as pd
1110
import torch
@@ -15,6 +14,7 @@
1514
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
1615
from tqdm import tqdm
1716

17+
import faiss # noqa: F401
1818
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup
1919

2020

examples/rag/README.md

+88
@@ -0,0 +1,88 @@
1+
# Intro
2+
RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator.
3+
During a forward pass, we encode the input with the question encoder and pass it
4+
to the retriever to extract relevant context documents. The documents are then prepended to the input.
5+
Such contextualized inputs are passed to the generator.
6+
7+
The question encoder can be any `autoencoding` model, preferably `DPRQuestionEncoder`, and the generator can be any `seq2seq` model, preferably `BartForConditionalGeneration`.
8+
9+
The model can be initialized with a `RagRetriever` for end-to-end generation or used in combination with the outputs of a retriever in multiple steps; see the examples for more details.
10+
The model is compatible with any `autoencoding` model as the `question_encoder` and any `seq2seq` model with a language model head as the `generator`.
11+
The model has been tested with `DPRQuestionEncoder` as the `question_encoder` and `BartForConditionalGeneration` or `T5ForConditionalGeneration` as the `generator`.
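
The forward pass described above (encode, retrieve, prepend, generate) can be sketched with plain Python. This is a toy illustration only: word overlap stands in for the dense question encoder and index, the passages are made up, and the ` / ` and ` // ` separators as well as all function names are assumptions chosen for the example.

```python
def score(question, passage):
    """Toy relevance score: number of shared lowercase words (stand-in for a dense dot product)."""
    return len(set(question.lower().split()) & set(passage["text"].lower().split()))


def retrieve(question, passages, n_docs=2):
    """Return the n_docs highest-scoring passages."""
    return sorted(passages, key=lambda p: score(question, p), reverse=True)[:n_docs]


def contextualize(question, docs, title_sep=" / ", doc_sep=" // "):
    """Prepend each retrieved document to the question: one generator input per document."""
    return [doc["title"] + title_sep + doc["text"] + doc_sep + question for doc in docs]


# Made-up passages for illustration.
passages = [
    {"title": "Reading F.C.", "text": "Reading football club is owned by Dai Yongge"},
    {"title": "Reba McEntire", "text": "Reba McEntire is an American country singer"},
    {"title": "Football", "text": "Football is a family of team sports"},
]
question = "who is the owner of reading football club"

docs = retrieve(question, passages, n_docs=2)
inputs = contextualize(question, docs)
for text in inputs:
    print(text)
```

Each string in `inputs` is what a generator would receive in this sketch: the retrieved document first, then the original question.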
12+
13+
RAG models were released with the paper [Retrieval-Augmented Generation for
14+
Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
15+
16+
17+
# Finetuning
18+
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).
19+
Follow instructions there regarding data preprocessing. A sample finetuning command:
20+
21+
```
22+
python examples/rag/finetune.py \
23+
--data_dir $DATA_DIR \
24+
--output_dir $OUTPUT_DIR \
25+
--model_name_or_path $MODEL_NAME_OR_PATH \
26+
--model_type rag_sequence \
27+
--fp16 \
28+
--gpus 8
29+
```
30+
31+
32+
# Evaluation
33+
Apart from the parameters specifying the model to evaluate and some extra parameters, the evaluation script expects paths to two files:
34+
- `evaluation_set` - a path to a file specifying the evaluation dataset, a single datapoint per line, e.g.
35+
```who is the owner of reading football club```
36+
- `gold_data_path` - a path to a file containing ground truth answers for datapoints from the `evaluation_set`.
37+
38+
We expect the following formats of the gold data file:
39+
40+
- for e2e evaluation, we support two formats of the gold file:
41+
- `qa` - where a single line is in the following format: input [tab] output_list, e.g.:
42+
```
43+
who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai']
44+
```
45+
- `ans` - where a single line of the gold file contains the expected output string, e.g.:
46+
```
47+
Xiu Li Dai
48+
```
49+
50+
- for retrieval evaluation, we expect a tab-separated list of Wikipedia page titles constituting positive contexts for a given query, e.g. given a question `who sings does he love me with reba`, a line with ground truth retrieval data could look as follows:
51+
```
52+
Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album)
53+
```
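
The gold file formats listed above can be read with a few lines of Python. This is a minimal sketch, assuming the fields are separated by a literal tab and that the `qa` answer list is written as a Python list literal; `parse_qa_line` and `parse_retrieval_line` are hypothetical helper names, not functions from the evaluation script.

```python
import ast


def parse_qa_line(line):
    """Split a `qa` line into the input question and its list of acceptable answers."""
    question, answers = line.rstrip("\n").split("\t")
    return question, ast.literal_eval(answers)


def parse_retrieval_line(line):
    """Split a retrieval gold line into a list of Wikipedia page titles."""
    return line.rstrip("\n").split("\t")


qa_line = "who is the owner of reading football club\t['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai']\n"
question, answers = parse_qa_line(qa_line)

titles = parse_retrieval_line("Does He Love You\tShoot for the Moon (album)\n")
print(question, answers, titles)
```

`ast.literal_eval` is used here rather than `eval` so that only Python literals are accepted from the file.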
54+
55+
## Retrieval evaluation
56+
57+
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download the respective files from the links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
58+
59+
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
60+
2. Parse the unzipped file using `parse_dpr_relevance_data.py`:
61+
```
62+
python examples/rag/parse_dpr_relevance_data.py --src_path path/to/unziped/biencoder-nq-dev.json --evaluation_set path/to/output/biencoder-nq-dev.questions --gold_data_path path/to/output/biencoder-nq-dev.pages
63+
```
64+
3. Run evaluation:
65+
```
66+
# --model_name_or_path: model name or path of the model we're evaluating
# --model_type: RAG model type (rag_token or rag_sequence)
# --evaluation_set: an input dataset for evaluation
# --gold_data_path: a dataset containing ground truth answers for samples from the evaluation_set
# --predictions_path: name of the file in which predictions will be stored
# --eval_mode: indicates whether we're performing retrieval evaluation or e2e evaluation
# --recalculate: if the predictions file already exists and this option is set, we regenerate the answers; otherwise we reuse the prediction file to calculate metrics
python examples/rag/eval_rag.py \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --evaluation_set path/to/output/biencoder-nq-dev.questions \
    --gold_data_path path/to/output/biencoder-nq-dev.pages \
    --predictions_path path/to/retrieval_preds.tsv \
    --eval_mode retrieval \
    --recalculate
74+
```
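
For intuition about what a retrieval score over page titles can look like, here is a minimal precision@k sketch. It is an illustration under stated assumptions, not necessarily the exact metric computed by `eval_rag.py`; the function name and the example titles used as predictions are hypothetical.

```python
def precision_at_k(predicted_titles, gold_titles, k):
    """Fraction of the top-k predicted titles that appear among the gold titles."""
    top_k = predicted_titles[:k]
    hits = len(set(top_k) & set(gold_titles))
    return hits / k


# Gold titles taken from the retrieval example above; predictions are made up.
gold = ["Does He Love You", "Shoot for the Moon (album)"]
predicted = ["Does He Love You", "Reba McEntire", "Shoot for the Moon (album)"]

p_at_2 = precision_at_k(predicted, gold, k=2)
p_at_3 = precision_at_k(predicted, gold, k=3)
print(p_at_2, p_at_3)
```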
75+
76+
77+
## End-to-end evaluation
78+
```
79+
# --eval_mode: indicates whether we're performing retrieval evaluation or e2e evaluation (default: e2e)
# --n_docs: you can experiment with retrieving a different number of documents at evaluation time
python examples/rag/eval_rag.py \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --evaluation_set path/to/test.source \
    --gold_data_path path/to/gold_data \
    --predictions_path path/to/e2e_preds.txt \
    --eval_mode e2e \
    --n_docs 5 \
    --print_predictions
88+
```

examples/rag/__init__.py

Whitespace-only changes.

examples/rag/callbacks.py

+30
@@ -0,0 +1,30 @@
1+
import logging
2+
import os
3+
4+
from pytorch_lightning.callbacks import ModelCheckpoint
5+
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def get_checkpoint_callback(output_dir, metric):
11+
"""Saves the best models by the chosen validation metric (rouge2, bleu or em)."""
12+
if metric == "rouge2":
13+
exp = "{val_avg_rouge2:.4f}-{step_count}"
14+
elif metric == "bleu":
15+
exp = "{val_avg_bleu:.4f}-{step_count}"
16+
elif metric == "em":
17+
exp = "{val_avg_em:.4f}-{step_count}"
18+
else:
19+
raise NotImplementedError(
20+
f"seq2seq callbacks only support rouge2, bleu and em, got {metric}. You can make your own by adding to this function."
21+
)
22+
23+
checkpoint_callback = ModelCheckpoint(
24+
filepath=os.path.join(output_dir, exp),
25+
monitor=f"val_{metric}",
26+
mode="max",
27+
save_top_k=3,
28+
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
29+
)
30+
return checkpoint_callback
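
The metric branches above map each supported metric to a checkpoint filename template. A minimal sketch of the same mapping as a lookup table follows; `METRIC_TEMPLATES` and `checkpoint_name_template` are hypothetical names, and plain `str.format` is used here only to show the shape of the template, which may differ from how the checkpoint callback actually fills it in.

```python
# Filename templates copied from the branches above, keyed by metric name.
METRIC_TEMPLATES = {
    "rouge2": "{val_avg_rouge2:.4f}-{step_count}",
    "bleu": "{val_avg_bleu:.4f}-{step_count}",
    "em": "{val_avg_em:.4f}-{step_count}",
}


def checkpoint_name_template(metric):
    """Return the filename template for a metric, or raise for unsupported metrics."""
    try:
        return METRIC_TEMPLATES[metric]
    except KeyError:
        supported = ", ".join(sorted(METRIC_TEMPLATES))
        raise NotImplementedError(f"unsupported metric {metric!r}; supported: {supported}") from None


# Illustrative values only.
example = checkpoint_name_template("em").format(val_avg_em=0.41237, step_count=1200)
print(example)
```

A table like this keeps the list of supported metrics and the error message in sync, since both are derived from the same dictionary.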
