Commit e33085d

updated the original RAG implementation to be compatible with the latest PyTorch Lightning (#11806)
* updated the original RAG implementation to be compatible with the latest PL version
* updated the requirements.txt file
* execute make style
* code quality test
* code quality
* conflict resolved in requirements.txt
* code quality
* changed the MyDDP class name to CustomDDP
1 parent 70f88ee commit e33085d

File tree

5 files changed: +26 -38 lines

examples/research_projects/rag/callbacks_rag.py
examples/research_projects/rag/distributed_ray_retriever.py
examples/research_projects/rag/finetune_rag.py
examples/research_projects/rag/lightning_base.py
examples/research_projects/rag/requirements.txt

Diff for: examples/research_projects/rag/callbacks_rag.py

+3-3
@@ -1,5 +1,4 @@
 import logging
-import os
 from pathlib import Path

 import numpy as np
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
     )

     checkpoint_callback = ModelCheckpoint(
-        filepath=os.path.join(output_dir, exp),
+        dirpath=output_dir,
+        filename=exp,
         monitor=f"val_{metric}",
-        mode="max",
+        mode="min",
         save_top_k=3,
         period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )

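Note on the change above: in PyTorch Lightning >= 1.0 the single filepath argument of ModelCheckpoint was split into dirpath and filename, which is what this diff tracks. A minimal sketch of the resulting callback follows; the directory, filename template, and "val_em" metric are illustrative placeholders, only the keyword names come from the diff.

from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch only: mirrors the keyword arguments used in get_checkpoint_callback above.
checkpoint_callback = ModelCheckpoint(
    dirpath="outputs/rag-finetune",     # directory part of the old filepath argument
    filename="checkpoint-{epoch:02d}",  # name template, the old basename part
    monitor="val_em",                   # f"val_{metric}" in the example script
    mode="min",
    save_top_k=3,
    period=1,  # still accepted in PL 1.3.x; later releases replace it
)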
Diff for: examples/research_projects/rag/distributed_ray_retriever.py

-3
@@ -3,7 +3,6 @@

 import ray
 from transformers import RagConfig, RagRetriever, RagTokenizer
-from transformers.file_utils import requires_datasets, requires_faiss
 from transformers.models.rag.retrieval_rag import CustomHFIndex


@@ -134,8 +133,6 @@ def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):

     @classmethod
     def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
-        requires_datasets(cls)
-        requires_faiss(cls)
         config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
         rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
         question_encoder_tokenizer = rag_tokenizer.question_encoder

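The requires_datasets / requires_faiss helpers imported above were removed from transformers.file_utils, so the commit drops the check entirely. If an explicit dependency check were still wanted here, the consolidated requires_backends helper could fill the same role; treat the snippet below as an assumption about the replacement, not something this commit adds.

from transformers.file_utils import requires_backends


def _check_retriever_backends(obj):
    # Assumed equivalent of the removed requires_datasets(cls) / requires_faiss(cls)
    # calls: raises an ImportError naming whichever backend is missing.
    requires_backends(obj, ["datasets", "faiss"])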
Diff for: examples/research_projects/rag/finetune_rag.py

+14-23
@@ -13,8 +13,8 @@
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
-from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
-from pytorch_lightning.cluster_environments import TorchElasticEnvironment
+import torch.distributed as torch_distrib
+from pytorch_lightning.plugins.training_type import DDPPlugin
 from torch.utils.data import DataLoader

 from transformers import (
@@ -36,7 +36,6 @@
 import ray
 from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever

-
 from callbacks_rag import (  # noqa: E402 # isort:skipq
     get_checkpoint_callback,
     get_early_stopping_callback,
@@ -74,27 +73,19 @@ def __init__(self, *args, **kwargs):
         self.__dict__ = self


-# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
-# is no longer used, and is moved into DDPAccelerator instead.
-# We override DDPAccelerator to add our custom logic for initializing the
-# retriever.
-# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
-
-
-class CustomAccel(DDPAccelerator):
-    def __init__(self, trainer=None, **kwargs):
-        # Trainer is set later.
-        super().__init__(trainer, **kwargs)
+class CustomDDP(DDPPlugin):
+    def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
+        module = self.model
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        if not torch.distributed.is_initialized():
+            logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        module = self.trainer.model
-        if self.cluster_environment is None:
-            self.cluster_environment = TorchElasticEnvironment()
-        self.distributed_port = module.hparams.distributed_port
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
         if module.is_rag_model:
+            self.distributed_port = module.hparams.distributed_port
             if module.distributed_retriever == "pytorch":
                 module.model.rag.retriever.init_retrieval(self.distributed_port)
             elif module.distributed_retriever == "ray" and global_rank == 0:
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         early_stopping_callback=es_callback,
         logger=training_logger,
-        accelerator=CustomAccel() if args.gpus > 1 else None,
+        custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
         profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")

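For orientation, here is a minimal sketch (assuming PyTorch Lightning 1.3.x and the CustomDDP class defined in the diff above) of how such a training-type plugin is handed to a Trainer. In this example project the plugin actually travels through generic_train(custom_ddp_plugin=...) in lightning_base.py, shown in the next diff; the Trainer arguments below are otherwise illustrative.

import pytorch_lightning as pl

# Sketch only: direct wiring of the CustomDDP plugin into a PL 1.3 Trainer.
trainer = pl.Trainer(
    gpus=2,                 # more than one GPU, so DDP is used
    accelerator="ddp",      # replaces the old distributed_backend="ddp" flag
    plugins=[CustomDDP()],  # PL calls init_ddp_connection() while setting up DDP,
                            # which also initializes the RAG retriever per the override
    max_epochs=3,
)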
Diff for: examples/research_projects/rag/lightning_base.py

+7-7
@@ -167,8 +167,8 @@ def total_steps(self) -> int:
         effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
         return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs

-    def setup(self, mode):
-        if mode == "test":
+    def setup(self, stage):
+        if stage == "test":
             self.dataset_size = len(self.test_dataloader().dataset)
         else:
             self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
@@ -341,6 +341,7 @@ def generic_train(
     args: argparse.Namespace,
     early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
+    custom_ddp_plugin=None,
     extra_callbacks=[],
     checkpoint_callback=None,
     logging_callback=None,
@@ -370,18 +371,17 @@
         train_params["amp_level"] = args.fp16_opt_level

     if args.gpus > 1:
-        train_params["distributed_backend"] = "ddp"
+        train_params["accelerator"] = "ddp"

     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
-    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
-    train_params["profiler"] = extra_train_kwargs.get("profiler", None)
+    train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)  # get unwanted logs

     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks,
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
+        plugins=[custom_ddp_plugin],
         logger=logger,
-        checkpoint_callback=checkpoint_callback,
         **train_params,
     )

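The setup rename above matches the PL 1.x LightningModule hook, which receives a stage string ("fit", "test", ...) rather than the old mode argument. A small illustrative sketch of the hook follows; the class and print statements are placeholders, not code from the example.

import pytorch_lightning as pl


class ExampleModule(pl.LightningModule):
    def setup(self, stage):
        # stage is "test" for trainer.test(...) and "fit" during training
        if stage == "test":
            print("measuring the test dataset size")
        else:
            print("building the train and validation dataloaders")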
Diff for: examples/research_projects/rag/requirements.txt

+2-2
@@ -3,5 +3,5 @@ datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
 transformers
-pytorch-lightning==1.0.4
-GitPython
+pytorch-lightning==1.3.1
+GitPython
