@@ -13,8 +13,8 @@
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
-from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
-from pytorch_lightning.cluster_environments import TorchElasticEnvironment
+import torch.distributed as torch_distrib
+from pytorch_lightning.plugins.training_type import DDPPlugin
 from torch.utils.data import DataLoader

 from transformers import (
@@ -36,7 +36,6 @@
 import ray
 from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever

-
 from callbacks_rag import (  # noqa: E402 # isort:skipq
     get_checkpoint_callback,
     get_early_stopping_callback,
@@ -74,27 +73,19 @@ def __init__(self, *args, **kwargs):
         self.__dict__ = self


-# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
-# is no longer used, and is moved into DDPAccelerator instead.
-# We override DDPAccelerator to add our custom logic for initializing the
-# retriever.
-# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
-
-
-class CustomAccel(DDPAccelerator):
-    def __init__(self, trainer=None, **kwargs):
-        # Trainer is set later.
-        super().__init__(trainer, **kwargs)
+class CustomDDP(DDPPlugin):
+    def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
+        module = self.model
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        if not torch.distributed.is_initialized():
+            logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        module = self.trainer.model
-        if self.cluster_environment is None:
-            self.cluster_environment = TorchElasticEnvironment()
-        self.distributed_port = module.hparams.distributed_port
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
         if module.is_rag_model:
+            self.distributed_port = module.hparams.distributed_port
             if module.distributed_retriever == "pytorch":
                 module.model.rag.retriever.init_retrieval(self.distributed_port)
             elif module.distributed_retriever == "ray" and global_rank == 0:
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         early_stopping_callback=es_callback,
         logger=training_logger,
-        accelerator=CustomAccel() if args.gpus > 1 else None,
+        custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
         profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
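Note (not part of the diff): the `custom_ddp_plugin` keyword in the last hunk belongs to this example's `generic_train` helper, which is assumed to forward the plugin to the Lightning Trainer. A minimal, hypothetical sketch of attaching such a custom training-type plugin to a PyTorch Lightning ~1.2 Trainer directly, with `MyRagModule` and `train_loader` as placeholders:

    import pytorch_lightning as pl

    # CustomDDP is the DDPPlugin subclass introduced in the diff above;
    # MyRagModule / train_loader stand in for a real LightningModule and DataLoader.
    trainer = pl.Trainer(
        gpus=2,
        accelerator="ddp",      # PL 1.2-era spelling; later releases use `strategy`
        plugins=[CustomDDP()],  # the custom plugin replaces the default DDPPlugin
        max_epochs=1,
    )
    trainer.fit(MyRagModule(), train_loader)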