Commit 42286d3

kazhang authored and facebook-github-bot committed
Use current device id in dist.barrier
Summary:
Pull Request resolved: facebookresearch#3350

`get_local_rank` relies on a global variable set by Detectron2's `launch` utils, but other frameworks may use Detectron2's distributed utils without launching through Detectron2's `launch` utils. Use `torch.cuda.current_device` to get the current device instead.

Reviewed By: HarounH, ppwwyyxx

Differential Revision: D30233746

fbshipit-source-id: 0b140ed5c1e7cd87ccf05235127f338ffc40a53d
1 parent 3924daf commit 42286d3
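To make the failure mode concrete, here is a minimal sketch (not part of the commit) of a worker launched by a generic utility such as `torchrun`; the `LOCAL_RANK` environment variable and the `main()` wrapper are conventions of that launcher, not Detectron2 API:

import os
import torch
import torch.distributed as dist

def main():
    # torchrun sets LOCAL_RANK (and MASTER_ADDR etc.) for each process;
    # Detectron2's module-level global behind get_local_rank() is never set.
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Before this commit, the barrier relied on Detectron2's launcher:
    #   dist.barrier(device_ids=[get_local_rank()])  # wrong under torchrun
    # After this commit, the device is queried from CUDA directly, so the
    # call works no matter which utility spawned the process:
    dist.barrier(device_ids=[torch.cuda.current_device()])

    dist.destroy_process_group()

if __name__ == "__main__":
    main()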

2 files changed (+4 −4 lines changed)

detectron2/engine/launch.py (+3 −3)

@@ -116,11 +116,11 @@ def _distributed_worker(
         if i == machine_rank:
             comm._LOCAL_PROCESS_GROUP = pg

+    assert num_gpus_per_machine <= torch.cuda.device_count()
+    torch.cuda.set_device(local_rank)
+
     # synchronize is needed here to prevent a possible timeout after calling init_process_group
     # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
     comm.synchronize()

-    assert num_gpus_per_machine <= torch.cuda.device_count()
-    torch.cuda.set_device(local_rank)
-
     main_func(*args)
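Why the move matters: `comm.synchronize()` now issues `dist.barrier(device_ids=[torch.cuda.current_device()])` on the NCCL backend, and `torch.cuda.current_device()` returns GPU 0 in every process until `set_device` is called. A condensed sketch of the ordering constraint (process-group setup and error handling omitted; `local_rank` is the per-machine GPU index as in `launch.py`):

import torch
import torch.distributed as dist

def worker_tail(local_rank: int, main_func, args):
    # Bind this process to its own GPU first; otherwise every worker on
    # the machine would report device 0 from torch.cuda.current_device().
    assert local_rank < torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Now the NCCL barrier resolves to the correct per-process device.
    dist.barrier(device_ids=[torch.cuda.current_device()])

    main_func(*args)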

detectron2/utils/comm.py (+1 −1)

@@ -83,7 +83,7 @@ def synchronize():
     if dist.get_backend() == dist.Backend.NCCL and TORCH_VERSION >= (1, 8):
         # This argument is needed to avoid warnings.
         # It's valid only for NCCL backend.
-        dist.barrier(device_ids=[get_local_rank()])
+        dist.barrier(device_ids=[torch.cuda.current_device()])
     else:
         dist.barrier()
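Putting the one-line change in context, the whole `synchronize()` helper after this commit looks roughly as follows; the early-return guards are paraphrased from Detectron2's `comm.py`, and the `TORCH_VERSION` parsing is a simplified stand-in for the real constant:

import torch
import torch.distributed as dist

# Simplified version parse; e.g. "1.10.0+cu113" -> (1, 10).
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])

def synchronize():
    """Barrier that is a no-op outside distributed training."""
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    if dist.get_backend() == dist.Backend.NCCL and TORCH_VERSION >= (1, 8):
        # device_ids avoids a warning; it's valid only for NCCL, and
        # current_device() works under any launcher (see commit summary).
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()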
