
Commit 4edc921

Yifu Wang authored and pytorchmergebot committed
Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (pytorch#114001)
## Summary

This PR adds 3 intra-node GPU allreduce algorithms to PyTorch:

- One-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate data from the other ranks.
- Two-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate `1 / world_size` of the data from the other ranks, then all ranks read the accumulated data from the other ranks (effectively a one-shot reduce-scatter followed by a one-shot all-gather).
- Hybrid cube mesh allreduce (original): a one-shot allreduce variant that avoids transmission over PCIe on a hybrid cube mesh (HCM) topology.

## Micro Benchmarks

![image](https://github.com/pytorch/pytorch/assets/4156752/7bd25ffc-cd5b-4acb-bd65-b01bc136726e)
![image](https://github.com/pytorch/pytorch/assets/4156752/3ced31b4-6c31-4f34-a2d8-c072df29ae0e)
![image](https://github.com/pytorch/pytorch/assets/4156752/5b942c05-4fcc-4ec9-ae29-12c64080bb1c)

## Details

The intra-node algorithms are organized behind `c10d::IntraNodeComm`, which is responsible for:

- Managing handshaking and CUDA IPC handle exchange among ranks.
- Querying NVLink connections and detecting the topology.
- Selecting an algorithm based on the available information.
- Launching the selected allreduce kernel.

`c10d::IntraNodeComm` is integrated into `c10d::ProcessGroupNCCL` as follows:

- When the `ENABLE_INTRA_NODE_COMM` environment variable is set, `c10d::ProcessGroupNCCL` initializes a `c10d::IntraNodeComm` for its ranks.
- If the setup is not suitable for intra-node comm (e.g. not all ranks are on the same node), the rendezvous logic guarantees that all participants fall back consistently.
- `c10d::ProcessGroupNCCL::allreduce` consults `c10d::IntraNodeComm` on whether to use intra-node allreduce and carries out the communication accordingly.

We currently detect two types of topologies from the NVLink connection mesh:

- Fully connected: every GPU pair has a direct NVLink connection (e.g. NVSwitch, or a fully connected subset of a hybrid cube mesh).
  - `msg <= 256KB`: one-shot allreduce.
  - `256KB < msg <= 10MB`: two-shot allreduce.
  - `msg > 10MB`: instructs the caller to fall back to NCCL.
- Hybrid cube mesh:
  - `msg <= 256KB`: one-shot allreduce.
  - `msg > 256KB`: instructs the caller to fall back to NCCL.

## Next Steps

- Fine-tune algorithm selection based on GPU model, topology, and link speed.
- Potentially optimize the two-shot allreduce implementation. According to FasterTransformer, two-shot allreduce is preferred up to 50MB, so there may be room for improvement, but PyTorch does impose more constraints:
  - FasterTransformer uses a single process to drive multiple devices, so it can use `cudaDeviceEnablePeerAccess` to enable device-level peer access.
  - PyTorch uses multiple processes to drive multiple devices. With CUDA IPC, a device can only share specific memory regions with other devices, which means extra copies may be unavoidable.

Pull Request resolved: pytorch#114001
Approved by: https://github.com/yf225
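To make the integration concrete, here is a minimal usage sketch modeled on the `test_intra_node_comm_all_reduce` test added in this PR: it opts into the intra-node path with `ENABLE_INTRA_NODE_COMM`, then issues a bf16 SUM allreduce small enough for the one-shot algorithm. The `torchrun` launch command and the exact tensor size are illustrative assumptions; the environment variable, the bf16/SUM constraint, and `_get_intra_node_comm_usage_counter` come from the changes in this commit.

```python
# Hypothetical single-node launch, e.g.: torchrun --nproc_per_node=8 this_script.py
# Assumes NVLink-connected, peer-accessible GPUs with compute capability >= 8.0.
import os

import torch
import torch.distributed as dist

os.environ["ENABLE_INTRA_NODE_COMM"] = "1"  # opt into the intra-node allreduce path

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# IntraNodeComm currently handles only ReduceOp.SUM on bf16; other dtypes/ops
# take the regular NCCL path. 256 KB of bf16 falls in the one-shot range.
t = torch.full((256 * 1024 // 2,), rank, dtype=torch.bfloat16, device="cuda")
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# The counter bound in this PR reports how often the intra-node path was taken.
from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
print(f"rank {rank}: intra-node allreduces so far = {_get_intra_node_comm_usage_counter()}")

dist.destroy_process_group()
```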
1 parent cd47e33 commit 4edc921

File tree

12 files changed: +1363, -7 lines changed


BUILD.bazel

Lines changed: 5 additions & 1 deletion
```diff
@@ -1452,7 +1452,10 @@ cu_library(
     # https://github.com/pytorch/pytorch/issues/79236
     # To solve it we add it into the `caffe2_cuda`,
     # this is also aligned with the CMake build.
-    srcs = [":caffe2_cu_srcs"] + ["torch/csrc/distributed/c10d/quantization/quantization_gpu.cu"],
+    srcs = [":caffe2_cu_srcs"] + [
+        "torch/csrc/distributed/c10d/intra_node_comm.cu",
+        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
+    ],
     copts = CAFFE2_COPTS + torch_cuda_half_options,
     visibility = ["//visibility:public"],
     deps = [
@@ -1619,6 +1622,7 @@ cc_library(
         exclude = [
             "torch/csrc/cuda/python_nccl.cpp",
             "torch/csrc/cuda/nccl.cpp",
+            "torch/csrc/distributed/c10d/intra_node_comm.cu",
             "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
         ],
     )) + torch_sources,
```

build_variables.bzl

Lines changed: 2 additions & 0 deletions
```diff
@@ -674,6 +674,8 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
     "torch/csrc/distributed/c10d/UCCTracing.cpp",
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
+    "torch/csrc/distributed/c10d/intra_node_comm.cpp",
+    "torch/csrc/distributed/c10d/intra_node_comm.cu",
     "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
 ]
```

c10/cuda/driver_api.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -37,7 +37,7 @@ void* DriverAPI::get_nvml_handle() {
   return nvml_hanle;
 }
 
-DriverAPI* DriverAPI::get() {
+C10_EXPORT DriverAPI* DriverAPI::get() {
   static DriverAPI singleton = create_driver_api();
   return &singleton;
 }
```

c10/cuda/driver_api.h

Lines changed: 5 additions & 3 deletions
```diff
@@ -28,9 +28,11 @@
   _(cuMemCreate) \
   _(cuGetErrorString)
 
-#define C10_NVML_DRIVER_API(_) \
-  _(nvmlInit_v2) \
-  _(nvmlDeviceGetHandleByPciBusId_v2) \
+#define C10_NVML_DRIVER_API(_)           \
+  _(nvmlInit_v2)                         \
+  _(nvmlDeviceGetHandleByPciBusId_v2)    \
+  _(nvmlDeviceGetNvLinkRemoteDeviceType) \
+  _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
   _(nvmlDeviceGetComputeRunningProcesses)
 
 namespace c10 {
```

caffe2/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
```diff
@@ -641,6 +641,10 @@ if(USE_CUDA)
   append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)
   if(NOT WIN32)
     append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
+    set_source_files_properties(
+      ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
+      PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
+    )
   endif()
 endif()
 set_source_files_properties(
```

test/distributed/test_c10d_nccl.py

Lines changed: 60 additions & 1 deletion
```diff
@@ -15,7 +15,7 @@
 from contextlib import contextmanager
 from datetime import datetime, timedelta
 from itertools import chain, product
-from unittest import mock
+from unittest import SkipTest, mock
 
 import torch
 import torch.distributed as c10d
@@ -3113,6 +3113,65 @@ def test_all_reduce_coalesced_nccl(self):
         for i, t in enumerate(tensors):
             self.assertEqual(t, torch.full_like(t, self.world_size * (i + (self.world_size + 1.) / 2.)))
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_intra_node_comm_all_reduce(self):
+        from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
+        from torch.testing._internal.common_cuda import SM80OrLater
+        for peer in range(self.world_size):
+            if peer == self.rank:
+                continue
+            if not torch._C._cuda_canDeviceAccessPeer(self.rank, peer):
+                raise SkipTest("Test requires p2p access")
+
+        if not SM80OrLater:
+            raise SkipTest("Test requires sm>=80")
+
+        store = c10d.FileStore(self.file_name, self.world_size)
+        os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
+        os.environ["TEST_INTRA_NODE_COMM"] = "1"
+        torch.cuda.set_device(self.rank)
+        c10d.init_process_group(
+            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+        expect = self.world_size * (self.world_size - 1) // 2
+
+        # IntraNodeComm currently only supports sum and bf16.
+        # Verify that it is not used in the next two configurations.
+        t = torch.full((4 * 1024 // 2,), self.rank).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
+
+        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.AVG)
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
+
+        # Verify that IntraNodeComm is used up to 10MB
+        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 1)
+
+        t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 2)
+
+        t = torch.full((10 * 1024 ** 2 // 2,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
+
+        # Verify that IntraNodeComm is not used beyond 10MB
+        t = torch.full((10 * 1024 ** 2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
+        c10d.all_reduce(t, c10d.ReduceOp.SUM)
+        self.assertTrue(t.eq(expect).all())
+        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
+
+        c10d.destroy_process_group()
+
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_sequence_num_set_default_pg_nccl(self):
```

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 18 additions & 1 deletion
```diff
@@ -712,7 +712,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       terminateProcessGroup_(false),
       terminateHeartbeatMonitorThread_(false),
       collectiveDebugInfoMode_(false),
-      uid_(process_group_id++) {
+      uid_(process_group_id++),
+      intraNodeComm_(initIntraNodeComm()) {
   TORCH_CHECK_WITH(
       ValueError,
       at::cuda::getNumGPUs() != 0,
@@ -895,6 +896,12 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
 #endif
 }
 
+c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
+    initIntraNodeComm() {
+  return intra_node_comm::IntraNodeComm::rendezvous(
+      store_, std::to_string(uid_), rank_, size_);
+}
+
 void ProcessGroupNCCL::runHealthCheck() {
   // Run health check in a separate thread and wait on CV to handle timeouts,
   // since majority of getNCCLComm failures are hangs.
@@ -2802,6 +2809,16 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
     std::vector<at::Tensor>& tensors,
     const AllreduceOptions& opts) {
+  if (intraNodeComm_ != nullptr && tensors.size() == 1 &&
+      opts.reduceOp == ReduceOp::SUM) {
+    using namespace intra_node_comm;
+    auto algo = intraNodeComm_->selectAllReduceAlgo(tensors[0]);
+    if (algo != intra_node_comm::AllReduceAlgo::NONE) {
+      intraNodeComm_->allReduce(tensors[0], algo);
+      return c10::make_intrusive<IntraNodeCommWork>();
+    }
+  }
+
   check_gpu_tensors_different_devices(tensors);
 
   // @lint-ignore CLANGTIDY
```

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 
 #include <ATen/DynamicLibrary.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -546,6 +547,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Provide an API for users to define their own ways to store NCCL debug info.
   void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
 
+  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
+
   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
   // instead of relying on ProcessGroupNCCL destructor.
   void abort(c10::optional<std::string> abortReason = c10::nullopt);
@@ -940,6 +943,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;
 
   size_t uid_;
+
+  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
 };
 
 TORCH_API std::string dump_nccl_trace();
```

torch/csrc/distributed/c10d/init.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@
 #ifdef USE_C10D_NCCL
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 #endif
 
 #ifdef USE_C10D_MPI
@@ -2328,6 +2329,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           "perform_nocolor_split",
           &::c10d::ProcessGroupNCCL::performNocolorSplit);
 
+  module.def(
+      "_get_intra_node_comm_usage_counter",
+      &::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
+
 #ifdef NCCL_HAS_COMM_CTA_CGA
   py::class_<ncclConfig_t>(
       processGroupNCCL,
```
